commit a291702af92b204701fb7a43472f922af1d4fea1 Author: Constantin Ruhdorfer Date: Tue Jun 25 16:22:33 2024 +0200 Init diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ca27637 --- /dev/null +++ b/.gitignore @@ -0,0 +1,28 @@ +src/overcooked_teacher_layout_imgs + +*~ +.venv +venv +env +!.gitkeep +tmp +.DS_Store +.idea +*.log +*.map +*.pyc +*.h5 +__pycache__/ +.pytest_cache +dist/ +**/data/ +**/logs/ +**/results/ +**/images/ +**/wandb/ +**/figures/ +**/config/wandb.json +!docs/images +src/*.egg-info +**/.ipynb_checkpoints/ +!src/minimax/config diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..6a3f28b --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "editor.codeActionsOnSave": {}, + "git.ignoreLimitWarning": true +} \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..962fee0 --- /dev/null +++ b/LICENSE @@ -0,0 +1,203 @@ + + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..3cb55f7 --- /dev/null +++ b/README.md @@ -0,0 +1,496 @@ +

The Overcooked Generalisation Challenge

+ +

+ + +This repository houses the Overcooked generalisation challange, a novel cooperative UED environment that explores the effect of generalisation on cooperative agents with a focus on zero-shot cooperation. +We built this work on top of [minimax](https://github.com/facebookresearch/minimax) (original README included below) and are inspired by many of their implementation details. + +We require Python to be above 3.9 and below 3.12, we use 3.10.12. +To install this research code use `pip install -r requirements.txt`. + +## Structure + +Our project inlcudes the following major components: + +- Overcooked UED +- Multi-Agent UED Runners +- Scripts for training and evaluations +- Holdout populations for evaluation (accesible [here](https://drive.google.com/drive/folders/11fxdhrRCSTmB7BvfqMGqdIhvJUDv_0zP?usp=share_link)) + +We highlight our additions to minimax below often with additional comments. +We choose minimax as the basis as it is tested and intended for this use case. +The project is structured as follows: + +``` +docs/ + envs/ + ... + overcooked.md (<- We document OvercookedUED here) + images/ + ... +examples/* +src/ + config/ + configs/ + maze/* + overcooked/* (<- Our configurations for all runs in the paper) + minimax/ + agents/ + ... + mappo.py (<- Our MAPPO interface for training) + config/* (<- logic related to configs, and getting commands, OvercookedUED included) + envs/ + ... + overcooked_proc/ (<- home of overcooked procedual content generation for UED) + ... + overcooked_mutators.py (<- For ACCEL) + overcooked_ood.py (<- Testing layouts (can be extended!)) + overcooked_ued.py (<- UED interface) + overcooked.py (<- Overcooked capable of being run in parallel across layouts) + models/ + ... + overcooked/ + ... + models.py (<- Models we use in the paper are defined here) + runners/* + runners_ma/* (<- multi-agent runners for Overcooked UED and potentially others) + tests/* + utils/* + arguments.py + count_params.py + evaluate_against_baseline.py + evaluate_against_population.py + evaluate_baseline_against_population.py + evaluate_from_pckl.py + evaluate.py + extract_fcp.py + train.py (<- minimax starting point, also for our work) + populations/ + fcp/* (see below) + baseline_train__${what} (Trains multiple self play agents across seeds) + eval_xpid_${what} (Evals populations, stay and random agents) + eval_xpid.sh (Evals a run based on its XPID) + extract_fcp.sh (Extracts FCP checkpoint from self-play agents) + make_cmd.sh (Extended with our work) + train_baseline_${method}_${architecture}.sh (Trains all methods in the paper) + train_maze_s5.sh + train_maze.sh +``` + +## Overcooked UED +We provide a detailed explanation of the environment in the paper. +OvercookedUED provides interfaces to both edit-based, generator-based and curator-based DCD methods. +For an overview see the figure above. + +## Mutli-Agent UED Runners +Multi-Agent runners are placed under `src/minimax/runners_ma`. +They extend the minimax runners by support for multiple agents, i.e. by carrying around hidden states etc. +Note: Our current implementation only features two agents. + +## Scripts + +Reproducability is important to us. +We thus store all important script in this repository that produce the policies discussed in the paper. +To generate a command, please use `make_cmd.sh` like so by specifying `overcooked` and the config file name: + +```bash +> ./make_cmd.sh overcooked baseline_dr_softmoe_lstm +python -m train \ +--seed=1 \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=dr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +... +``` + +They are named `train_baseline_${method}_${architecture}.sh` and can be found in `src`. +`${method}` specifies the DCD method and can be from {`p_accel`, `dr`, `pop_paired`, `p_plr`} which correspond to parallel ACCEL (https://arxiv.org/abs/2203.01302 & https://arxiv.org/abs/2311.12716), domain randimisation (https://arxiv.org/abs/1703.06907), population paired (https://arxiv.org/abs/2012.02096) and parallel PLR (https://arxiv.org/abs/2010.03934 & https://arxiv.org/abs/2311.12716). +`${architecture}` on the other hand corresponds to the neural network architechture employed and can be from {`lstm`, `s5`, `softmoe`}. +To use them, please set the environment variable `${WANDB_ENTITY}` to your wandb user name or specify `wandb_mode=offline`. +The scripts can be called like this: + +```bash +./train_baseline_p_plr_s5.sh $device $seed +``` + +The scripts run `src/minimax/train.py` and store their results to the configured locations (see the config jsons and the `--log_dir` flag) but usually somewhere in your home directory `~/logs/`. +There are 12 train scripts and helper scripts that run multiple variations of these after the other, i.e. like in `train_baselines_s56x9.sh` that trains all 4 DCD methods with an S5 policy: + +```bash +DEFAULTVALUE=4 +DEFAULTSEED=1 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +echo "Using device ${device} and seed ${seed}" + +./train_baseline_p_plr_s5.sh $device $seed +./train_baseline_p_accel_s5.sh $device $seed +./train_baseline_pop_paired_s5.sh $device $seed +./train_baseline_dr_s5.sh $device $seed +``` + +Evaluation is performed via scripts starting with `eval`. +One can evaluate against scripted agents `eval_stay_against_population.sh` and random ones via `eval_random_against_population.sh`. +To evaluate against a population using a trained agent use `eval_xpid_against_population.sh` with device 4 and the agents XPID `YOUR_XPID` you can use `./eval_xpid_against_population.sh 4 YOUR_XPID`. + +## Holdout populations for evaluation + +The populations can be accessed here: https://drive.google.com/drive/folders/11fxdhrRCSTmB7BvfqMGqdIhvJUDv_0zP?usp=share_link. +They need to be placed under `src/populations` to work with the provided scripts. +Alternatively -- if desired -- populations can be obtained by running `src/baseline_train__all.sh` or alternatively by using `src/baseline_train__8_seeds.sh` for the desired layout, i.e. via: + +```bash +./baseline_train__8_seeds.sh $device coord_ring_6_9 +``` + +We exclude the detailed calls here as they are too verbose. +The resulting directory structure for inlcuding the poppulations should look like the following: + +```txt +src/ + minimax + ... + populations/ + fcp/ + Overcooked-AsymmAdvantages6_9/ + 1/ + high.pkl + low.pkl + meta.json + mid.pkl + xpid.txt + 2/* + ... + 8/* + population.json + Overcooked-CoordRing6_9/* + Overcooked-CounterCircuit6_9/* + Overcooked-CrampedRoom6_9/* + Overcooked-ForcedCoord6_9/* +``` + +To work with these populations meta files point to the correct scripts. +These are included in the downloadable zip, called `population.json` (see above) and should look like this: + +```json +{ + "population_size": 24, + "1": "populations/fcp/Overcooked-AsymmAdvantages6_9/1/low.pkl", + "2": "populations/fcp/Overcooked-AsymmAdvantages6_9/1/mid.pkl", + ... + "24": "populations/fcp/Overcooked-AsymmAdvantages6_9/8/high.pkl", + "1_meta": "populations/fcp/Overcooked-AsymmAdvantages6_9/1/meta.json", + "2_meta": "populations/fcp/Overcooked-AsymmAdvantages6_9/1/meta.json", + ... + "24_meta": "populations/fcp/Overcooked-AsymmAdvantages6_9/8/meta.json" +} +``` + +They help our evaluation to keep track of the correct files to use. + +To check whether they work correctly use something along the lines of (compare the eval scripts): + +```bash +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +for env in "Overcooked-CoordRing6_9" "Overcooked-ForcedCoord6_9" "Overcooked-CounterCircuit6_9" "Overcooked-AsymmAdvantages6_9" "Overcooked-CrampedRoom6_9"; +do + CUDA_VISIBLE_DEVICES=${device} LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.evaluate_baseline_against_population \ + --env_names=${env} \ + --population_json="populations/fcp/${env}/population.json" \ + --n_episodes=100 \ + --is_random=True +done +``` + +## Credit the minimax authors + +For attribution in academic contexts please also cite the original work on minimax: +``` +@article{jiang2023minimax, + title={minimax: Efficient Baselines for Autocurricula in JAX}, + author={Jiang, Minqi and Dennis, Michael and Grefenstette, Edward and RocktΓ€schel, Tim}, + booktitle={Agent Learning in Open-Endedness Workshop at NeurIPS}, + year={2023} +} +``` + +The original readme is included below. + +
+ +
+ +
+ +
+ +

Original Minimax Readme

+ +

+ +

Efficient baselines for autocurricula in JAX

+ +

+ + + + +

+ +## Contents +- [Why `minimax`?](#-why-minimax) + - [Hardware-accelerated baselines](#-hardware-accelerated-baselines) +- [Install](#%EF%B8%8F-install) +- [Quick start](#-quick-start) +- [Dive deeper](#-dive-deeper) + - [Training](#training) + - [Logging](#logging) + - [Checkpointing](#checkpointing) + - [Evaluating](#evaluating) +- [Environments](#%EF%B8%8F-environments) + - [Supported environments](#supported-environments) + - [Adding environments](#adding-environments) +- [Agents](#-agents) +- [Roadmap](#-roadmap) +- [License](#-license) +- [Citation](#-citation) + +## 🐒 Why `minimax`? + +Unsupervised Environment Design (UED) is a promising approach to generating autocurricula for training robust deep reinforcement learning (RL) agents. However, existing implementations of common baselines require excessive amounts of compute. In some cases, experiments can require more than a week to complete using V100 GPUs. **This long turn-around slows the rate of research progress in autocuriculum methods**. `minimax` provides fast, [JAX-based](https://github.com/google/jax) implementations of key UED baselines, which are based on the concept of _minimax_ regret. By making use of fully-tensorized environment implementations, `minimax` baselines are fully-jittable and thus take full advantage of the hardware acceleration offered by JAX. In timing studies done on V100 GPUs and Xeon E5-2698 v4 CPUs, we find `minimax` baselines can run **over 100x faster than previous reference implementations**, like those in [facebookresearch/dcd](https://github.com/facebookresearch/dcd). + +All autocurriculum algorithms implemented in `minimax` also support multi-device training, which can be activated through a [single command line flag](#multi-device-training). Using multiple devices for training can lead to further speed ups and allows scaling these autocurriculum methods to much larger batch sizes. + + + + Shows Anuraghazra's GitHub Stats. + + +### πŸ‡ Hardware-accelerated baselines + +`minimax` includes JAX-based implementations of + +- [Domain Randomization (DR)](https://arxiv.org/abs/1703.06907) + +- [Minimax adversary](https://arxiv.org/abs/2012.02096) + +- [PAIRED](https://arxiv.org/abs/2012.02096) + +- [Population PAIRED](https://arxiv.org/abs/2012.02096) + +- [Prioritized Level Replay (PLR)](https://arxiv.org/abs/2010.03934) + +- [Robust Prioritized Level Replay (PLR$`^{\perp}`$)](https://arxiv.org/abs/2110.02439) + +- [ACCEL](https://arxiv.org/abs/2203.01302) + +Additionally, `minimax` includes two new variants of PLR and ACCEL that further reduce wall time by better leveraging the massive degree of environment parallelism enabled by JAX: + +- Parallel PLR (PLR$`^{||}`$) + +- Parallel ACCEL (ACCEL$`^{||}`$) + +In brief, these two new algorithms collect rollouts for new level evaluation, level replay, and, in the case of Parallel ACCEL, mutation evaluation, all in parallel (i.e. rather than sequentially, as done by Robust PLR and ACCEL). As a simple example for why this parallelization improves wall time, consider how Robust PLR with replay probability of `0.5` would require approximately 2x as many rollouts in order to reach the same number of RL updates as a method like DR, because updates are only performed on rollouts based on level replay. Parallelizing level replay rollouts alongside new level evaluation rollouts by using 2x the environment parallelism reduces the total number of parallel rollouts to equal the total number of updates desired, thereby matching the 1:1 rollout to update ratio of DR. The diagram below summarizes this difference. + +![Parallel DCD overview](docs/images/parallel_dcd_overview.png) + +`minimax` includes a fully-tensorized implementation of a maze environment that we call [`AMaze`](docs/envs/maze.md). This environment exactly reproduces the MiniGrid-based mazes used in previous UED studies in terms of dynamics, reward function, observation space, and action space, while running many orders of magnitude faster in wall time, with increasing environment parallelism. + + +## πŸ› οΈ Install + +1. Use a virtual environment manager like `conda` or `mamba` to create a new environment for your project: + +```bash +conda create -n minimax +conda activate minimax +``` + +2. Install `minimax` via either `pip install minimax-lib` or `pip install ued`. + +3. That's it! + +⚠️ Note that to enable hardware acceleration on GPU, you will need to make sure to install the latest version of `jax>=0.4.19` and `jaxlib>=0.4.19` that is compatible with your CUDA driver (requires minimum CUDA version of `11.8`). See [the official JAX installation guide](https://jax.readthedocs.io/en/latest/installation.html#pip-installation-gpu-cuda-installed-via-pip-easier) for detailed instructions. + +## 🏁 Quick start + +The easiest way to get started is to play with the Python notebooks in the [examples folder](examples) of this repository. We also host Colab versions of these notebooks: + +- DR [[IPython](examples/dr.ipynb), [Colab](https://colab.research.google.com/drive/1HhgQgcbt77uEtKnV1uSzDsWEMlqknEAM)] + +- PAIRED [[IPython](examples/paired.ipynb), [Colab](https://colab.research.google.com/drive/1NjMNbQ4dgn8f5rt154JKDnXmQ1yV0GbT?usp=drive_link)] + +- PLR and ACCEL*: [[IPython](examples/plr.ipynb), [Colab](https://colab.research.google.com/drive/1XqVRgcIXiMDrznMIQH7wEXjGZUdCYoG9?usp=drive_link)] + +*Depending on how the top-level flags are set, this notebook runs PLR, Robust PLR, Parallel PLR, ACCEL, or Parallel ACCEL. + +`minimax` comes with high-performing hyperparameter configurations for several algorithms, including domain randomization (DR), PAIRED, PLR, and ACCEL for 60-block mazes. You can train using these settings by first creating the training command for executing `minimax.train` using the convenience script [`minimax.config.make_cmd`](docs/make_cmd.md): + +`python -m minimax.config.make_cmd --config maze/[dr,paired,plr,accel] | pbcopy`, + +followed by pasting and executing the resulting command into the command line. + +[See the docs](docs/make_cmd.md) for `minimax.config.make_cmd` to learn more about how to use this script to generate training commands from JSON configurations. You can browse the available JSON configurations for various autocurriculum methods in the [configs folder](config/configs). + +Note that when logging and checkpointing are enabled, the main `minimax.train` script outputs this data as `logs.csv` and `checkpoint.pkl` respectively in an experiment directory located at `/`, where `log_dir` and `xpid` are arguments specified in the command. You can then evaluate the checkpoint by using `minimax.evaluate`: + +```bash +python -m minimax.evaluate \ +--seed 1 \ +--log_dir \ +--xpid_prefix \ +--env_names \ +--n_episodes \ +--results_path \ +--results_fname +``` + +Some behaviors of `minimax.evaluate` to be aware of: +- This command will search `log_dir` for all experiment directories with names matching `xpid_prefix` and evaluate the checkpoint named `.pkl`. +- `minimax.evaluate` assumes xpid values end with a unique index, so that they match the regex `.*_[0-9]+$`. +- The results will be averaged over all such checkpoints (at most one checkpoint per matching experiment folder). Using the `--xpid_prefix` argument can be useful for evaluating corresponding to the same experimental configuration with different training seeds (and thus share an xpid prefix, e.g. , , ). + +If you would like to evaluate a checkpoint for only a single experiment, specify the full experiment directory name using `--xpid` instead of using `--xpid_prefix`. + + +## All command-line arguments +| Argument | Description | +| ----------------- | -------------------------------------------------------------------------------------------------------------------------------- | +| `seed` | Random seed for evaluation | +| `log_dir` | Directory containing experiment folders | +| `xpid` | Name of experiment folder, i.e. the experiment ID | +| `xpid_prefix` | Evaluate and average results over checkpoints for experiments with experiment IDs matching this prefix (ignores `--xpid` if set) | +| `checkpoint_name` | Name of checkpoint to evaluate (in each matching experiment folder) | +| `env_names` | Number of devices over which to shard the environment batch dimension | +| `n_episodes` | Number of students in the autocurriculum | +| `agent_idxs` | Indices of student agents to evaluate (csv of indices or `*` for all indices) | +| `results_path` | Number of parallel environments | +| `results_fname` | Number of parallel trials per environment (environment) | +| `render_mode` | If set, renders the evaluation episode. Requires disabling JIT. Use `'ipython'` if rendering inside an IPython notebook. | diff --git a/docs/images/OvercookedDCD.png b/docs/images/OvercookedDCD.png new file mode 100644 index 0000000..89099ea Binary files /dev/null and b/docs/images/OvercookedDCD.png differ diff --git a/docs/images/Training6x9SmallStylised.pdf b/docs/images/Training6x9SmallStylised.pdf new file mode 100644 index 0000000..c9d78a9 Binary files /dev/null and b/docs/images/Training6x9SmallStylised.pdf differ diff --git a/docs/images/Training6x9SmallStylised.png b/docs/images/Training6x9SmallStylised.png new file mode 100644 index 0000000..113c861 Binary files /dev/null and b/docs/images/Training6x9SmallStylised.png differ diff --git a/docs/images/env_maze_overview.png b/docs/images/env_maze_overview.png new file mode 100644 index 0000000..9f0e80d Binary files /dev/null and b/docs/images/env_maze_overview.png differ diff --git a/docs/images/minimax_logo.png b/docs/images/minimax_logo.png new file mode 100644 index 0000000..ff8f278 Binary files /dev/null and b/docs/images/minimax_logo.png differ diff --git a/docs/images/minimax_speedups.png b/docs/images/minimax_speedups.png new file mode 100644 index 0000000..262139a Binary files /dev/null and b/docs/images/minimax_speedups.png differ diff --git a/docs/images/minimax_speedups_darkmode.png b/docs/images/minimax_speedups_darkmode.png new file mode 100644 index 0000000..08c49cb Binary files /dev/null and b/docs/images/minimax_speedups_darkmode.png differ diff --git a/docs/images/minimax_system_diagram.png b/docs/images/minimax_system_diagram.png new file mode 100644 index 0000000..35ba777 Binary files /dev/null and b/docs/images/minimax_system_diagram.png differ diff --git a/docs/images/parallel_dcd_overview.png b/docs/images/parallel_dcd_overview.png new file mode 100644 index 0000000..3644b82 Binary files /dev/null and b/docs/images/parallel_dcd_overview.png differ diff --git a/docs/make_cmd.md b/docs/make_cmd.md new file mode 100644 index 0000000..14ef6ed --- /dev/null +++ b/docs/make_cmd.md @@ -0,0 +1,28 @@ +# Generating commands + +The `minimax.config.make_cmd` module enables generating batches of commands from a JSON configuration file, e.g. for running array jobs with Slurm. The JSON should adhere to the following format: +- Each key is a valid command-line argument for `minimax.train`. +- Each value is a list of values for the corresponding command-line argument. Commands are generated for each combination of command-line argument values. +- Boolean values should be specified as 'True' or 'False'. +- If a value is specified as `null`, the associated command-line argument is not included in the generated command (and thus would take on the default value specified when defining the argument parser). + +You can try it out by running the following command in your project root directory: + +``` +python -m minimax.config.make_cmd --config maze/plr +``` + +The above command will create a directory called `config` in the calling directory with a subdirectory `config/maze` containing configuration files for several autocurriculum methods. + +By default, `minimax.config.make_cmd` searches for configuration files inside `config`. You can create your own JSON configuration files within `config`. If your JSON configuration is located at `config/path/to/my/json`, then you can generate commands with it by calling `minimax.config.make_cmd --config path/to/my/json`. + +## Configuring `wandb` + +If your configuration includes the argument `wandb_project`, then `minimax.config.make_cmd` will look for a JSON dictionary with your credentials at `config/wandb.json`. The expected format of this JSON file is + +```json +{ + "base_url": , + "api_key": +} +``` \ No newline at end of file diff --git a/docs/parsnip.md b/docs/parsnip.md new file mode 100644 index 0000000..ccce24f --- /dev/null +++ b/docs/parsnip.md @@ -0,0 +1,131 @@ +# `Parsnip` + +## πŸ₯• `argparse` with conditional argument groups. + +As `minimax.train` is the single point-of-entry for training, its command-line arguments can grow quickly in number with each additional autocurriculum method supported in `minimax`. This complexity arises for several reasons: + +- New components in the form of training runners, environments, agents, and models may require additional arguments +- New components may require existing arguments shared with previous components +- New components may overload the meaning of existing arguments used by other components + +We make use of a custom module called `Parsnip` to help manage the complexity of specifying and parsing command-line arguments. `Parsnip` allows the creation of named argument groups, which allows adding new arguments while explicitly separating them into name spaces. Each argument group results in its own kwarg dictionary when parsed. + +`Parsnip` directly builds on `argparse` by adding the notion of a "subparser". Here, a subparser is simply an `argparse` parser responsible for a named argument group. Subparsers enable some useful behavior: +- Arguments can be added to the top-level `Parsnip` parser or to a subparser. +- Each subparser is initialized with a `name` for its corresponding argument group. All arguments under this subparser will be contained in a nested kwarg dictionary under the key equal to `name`. +- Each subparser can be initialized with an optional `prefix`, in which case all command-line arguments added to the subparser will be prepended with the value of `prefix` (see example below), thus creating a namespace for the corresponding argument group. +- Subparsers can be added conditionally, based on the specific value of a top-level argument (with support for the wildcard `*`). +- After parsing, `Parsnip` produces a kwargs dictionary containing a key:value pair for each top-level argument and a nested kwargs dictionary, under the key `` containing the parsed arguments managed by each active subparser initialized with `prefix=`. + +Other than these details, `Parsnip`'s interface remains identical to that of `argparse`. + +## A minimal example +In this example, we assume the parser is used inside a script called `run.py`. + +```python +from util.parsnip import Parsnip + +# Create a new Parsnip parser +parser = Parsnip() + +# Add some top-level arguments (same as argparse) +parser.add_argument( + '--name', + type=str, + help='Name of my farm.') +parser.add_argument( + '--kind', + type=str, + choices=['apple', 'radish'], + help='What kind of farm I run.') +parser.add_argument( + '--n_acres', + type=str, + help='Size of my farm in acres.') + +# Create a nested argument group with a prefix +crop_subparser = parser.add_subparser(name='crop', prefix='crop') +parser.add_argument( + '--n_acres', + type=str, + help='Size of land for growing radish, in acres.') + +# Create a conditional argument group +radish_subparser = parser.add_subparser( + name='radish', + prefix='radish', + dependency={'crop': 'radish'}, + dest='crop') +radish_subparser.add_argument( + '--is_pickled' + type=str2bool, + default=False, + help='Whether my farm produces pickled radish.') + +# Create another conditional argument group +apple_subparser = parser.add_subparser( + name='apple', + prefix='apple', + dependency={'crop': 'apple'}, + dest='crop') +apple_subparser.add_argument( + '--kind' + type=str, + choices=['fuji', 'mcintosh'], + default='fuji', + help='Whether my farm produces pickled radish.') + +args = parser.parse_args() +``` + +Then running this command + +```bash +python run.py \ +--name 'Radelicious Farms' \ +--kind radish \ +--n_acres 200 \ +--crop_n_acres 150 \ +--radish_is_pickled +``` + +would produce this kwargs dictionary: + +```python +{ + 'name': 'Radelicious Farms', + 'kind': 'radish', + 'n_acres': 200, + 'crop_args': { + 'n_acres': 150, + 'is_pickled': True + } +} +``` + +Notice how the `prefix` for each subparser is appended to each argument name added to that subparser (e.g. `n_acres` became `crop_n_acres`, and `is_pickled` became `radish_is_pickled`). Also notice how the `radish_is_pickled` argument became active, as its activation conditions on `kind=radish`, as we specified when defining the `radish_subparser`. + +Likewise, running this argument + +```bash +python run.py \ +--name 'Appledores Farms' \ +--kind apple \ +--n_acres 200 \ +--crop_n_acres 150 \ +--apple_kind fuji +``` + +results in this kwargs dictionary: + +```python +{ + 'name': 'Appledores Farms', + 'kind': 'apple', + 'n_acres': 200, + 'crop_args': { + 'n_acres': 150, + 'kind': 'fuji' + } +} +``` \ No newline at end of file diff --git a/docs/train_args.md b/docs/train_args.md new file mode 100644 index 0000000..17262e4 --- /dev/null +++ b/docs/train_args.md @@ -0,0 +1,125 @@ +# Command-line usage guide for `minimax.train` + +Parsing command-line arguments is handled by [`Parsnip`](parsnip.md). + +You can quickly generate batches of training commands from a JSON configuration file using [`minimax.config.make_cmd`](make_cmd.md). + +## General arguments + +| Argument | Description | +| ----------------------- | ---------------------------------------------------------------------------------------------------- | +| `seed` | Random seed, should be unique per experimental run | +| `agent_rl_algo` | Base RL algorithm used for training (e.g. PPO) | +| `n_total_updates` | Total number of updates for the training run | +| `train_runner` | Which training runner to use, e.g. `dr`, `plr`, or `paired` | +| `n_devices` | Number of devices over which to shard the environment batch dimension | +| `n_students` | Number of students in the autocurriculum | +| `n_parallel` | Number of parallel environments | +| `n_eval` | Number of parallel trials per environment (environment batch dimension is then `n_parallel*n_eval`) | +| `n_rollout_steps` | Number of steps per rollout (used for each update cycle) | +| `lr` | Learning rate | +| `lr_final` | Final learning rate, based on linear schedule. Defaults to `None`, corresponding to no schedule. | +| `lr_anneal_steps` | Number of steps over which to linearly anneal from `lr` to `lr_final` | +| `student_value_coef` | Value loss coefficient | +| `student_entropy_coef` | Entropy bonus coefficient | +| `student_unroll_update` | Unroll multi-gradient updates this many times (can lead to speed ups) | +| `max_grad_norm` | Clip gradients beyond this magnitude | +| `adam_eps` | Value of $`\epsilon`$ numerical stability constant for Adam | +| `discount` | Discount factor $`\gamma`$ for the student's RL optimization | +| `n_unroll_rollout` | Unroll rollout scans this many times (can lead to speed ups) | + +## Logging arguments + +| Argument | Description | +| ------------------- | -------------------------------------------------------- | +| `verbose` | Random seed, should be unique per experimental run | +| `track_env_metrics` | Track per rollout batch environment metrics if `True` | +| `log_dir` | Path to directory storing all experiment folders | +| `xpid` | Unique name for experiment folder, stored in `--log_dir` | +| `log_interval` | Log training statistics every this many rollout cycles | +| `wandb_base_url` | Base API URL if logging with `wandb` | +| `wandb_api_key` | API key for `wandb` | +| `wandb_entity` | `wandb` entity associated with the experiment run | +| `wandb_project` | `wandb` project for the experiment run | +| `wandb_group` | `wandb` group for the experiment run | + +## Checkpointing arguments + +| Argument | Description | +| ---------------------- | ----------------------------------------------------------------------------- | +| `checkpoint_interval` | Random seed, should be unique per experimental run | +| `from_last_checkpoint` | Begin training from latest `checkpoint.pkl`, if any, in the experiment folder | +| `archive_interval` | Save an additional checkpoint for models trained per this many rollout cycles | + +## Evaluation arguments + +| Argument | Description | +| ----------------- | -------------------------------------------------------------------- | +| `test_env_names` | Random seed, should be unique per experimental run | +| `test_n_episodes` | Average test results over this many episodes per test environment | +| `test_agent_idxs` | Test agents at these indices (csv of indices or `*` for all indices) | + +## PPO arguments + +These arguments activate when `--agent_rl_algo=ppo`. + +| Argument | Description | +| ----------------------------- | ----------------------------------------------------------- | +| `student_ppo_n_epochs` | Random seed, should be unique per experimental run | +| `student_ppo_n_epochs` | Number of PPO epochs per update cycle | +| `student_ppo_n_minibatches` | Number of minibatches per PPO epoch | +| `student_ppo_clip_eps` | Clip coefficient for PPO | +| `student_ppo_clip_value_loss` | Perform value clipping if `True` | +| `gae_lambda` | Lambda discount factor for Generalized Advantage Estimation | + +## PAIRED arguments + +The arguments in this section activate when `--train_runner=paired`. + +| Argument | Description | +| ------------------------- | --------------------------------------------------------------------- | +| `teacher_lr` | Learning rate | +| `teacher_lr_final` | Anneal learning rate to this value (defaults to `teacher_lr`) | +| `teacher_lr_anneal_steps` | Number of steps over which to linearly anneal from `lr` to `lr_final` | +| `teacher_discount` | Discount factor, $`\gamma`$ | +| `teacher_value_loss_coef` | Value loss coefficient | +| `teacher_entropy_coef` | Entropy bonus coefficient | +| `teacher_n_unroll_update` | Unroll multi-gradient updates this many times (can lead to speed ups) | +| `ued_score` | Name of UED objective, e.g. `relative_regret` | + +These PPO-specific arguments for teacher optimization further activate when `--agent_rl_algo=ppo`. + +| Argument | Description | +| ----------------------------- | ----------------------------------------------------------- | +| `teacher_ppo_n_epochs` | Number of PPO epochs per update cycle | +| `teacher_ppo_n_minibatches` | Number of minibatches per PPO epoch | +| `teacher_ppo_clip_eps` | Clip coefficient for PPO | +| `teacher_ppo_clip_value_loss` | Perform value clipping if `True` | +| `teacher_gae_lambda` | Lambda discount factor for Generalized Advantage Estimation | + +## PLR arguments + +The arguments in this section activate when `--train_runner=paired`. + +| Argument | Description | +| ----------------------------- | ------------------------------------------------------------------------------------------------------------- | +| `ued_score` | Name of UED objective (aka PLR scoring function) | +| `plr_replay_prob` | Replay probability | +| `plr_buffer_size` | Size of level replay buffer | +| `plr_staleness_coef` | Staleness coefficient | +| `plr_temp` | Score distribution temperature | +| `plr_use_score_ranks` | Use rank-based prioritization (rather than proportional) | +| `plr_min_fill_ratio` | Only replay once level replay buffer is filled above this ratio | +| `plr_use_robust_plr` | Use robust PLR (i.e. only update policy on replay levels) | +| `plr_force_unique` | Force level replay buffer members to be unique | +| `plr_use_parallel_eval` | Use Parallel PLR or Parallel ACCEL (if `plr_mutation_fn` is set) | +| `plr_mutation_fn` | If set, PLR becomes ACCEL. Use `'default'` for default mutation operator per environment. | +| `plr_n_mutations` | Number of applications of `plr_mutation_fn` per mutation cycle. | +| `plr_mutation_criterion` | How replay levels are selected for mutation (e.g. `batch`, `easy`, `hard`). | +| `plr_mutation_subsample_size` | Number of replay levels selected for mutation according to the criterion (ignored if using `batch` criterion) | + +## Environment-specific arguments + +### Maze + +See the [`AMaze`](envs/maze.md) docs for details on how to specify [training](envs/maze.md#student-environment), [evaluation](envs/maze.md#student-environment), and [teacher-specific](envs/maze.md#teacher-environment) environment parameters via command line diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..f1de7d6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +numpy>=1.25,<1.26 +pandas==1.5.3 +jax>=0.4.19 +jaxlib>=0.4.19 +flax>=0.7.4 +optax>=0.1.7 +chex>=0.1.83 +wandb>=0.13 +ipython>=7.34.0 +GitPython>=3.1.29 \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/baseline_train__8_seeds.sh b/src/baseline_train__8_seeds.sh new file mode 100755 index 0000000..a9730f9 --- /dev/null +++ b/src/baseline_train__8_seeds.sh @@ -0,0 +1,72 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" +layout=$2 + +seed_max=8 + +for seed in `seq ${seed_max}`; +do + echo "seed is ${seed}:" + CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ + --wandb_mode=online \ + --wandb_project=overcooked-minimax-jax \ + --wandb_entity=${WANDB_ENTITY} \ + --seed=${seed} \ + --agent_rl_algo=ppo \ + --n_total_updates=1000 \ + --train_runner=dr \ + --n_devices=1 \ + --student_model_name=default_student_actor_cnn \ + --student_critic_model_name=default_student_critic_cnn \ + --env_name=Overcooked \ + --is_multi_agent=True \ + --verbose=False \ + --log_dir=~/logs/minimax \ + --log_interval=10 \ + --from_last_checkpoint=False \ + --checkpoint_interval=25 \ + --archive_interval=25 \ + --archive_init_checkpoint=False \ + --test_interval=50 \ + --n_students=1 \ + --n_parallel=100 \ + --n_eval=1 \ + --n_rollout_steps=400 \ + --lr=3e-4 \ + --lr_anneal_steps=0 \ + --max_grad_norm=0.5 \ + --adam_eps=1e-05 \ + --track_env_metrics=True \ + --discount=0.99 \ + --n_unroll_rollout=10 \ + --render=False \ + --student_gae_lambda=0.95 \ + --student_entropy_coef=0.01 \ + --student_value_loss_coef=0.5 \ + --student_n_unroll_update=5 \ + --student_ppo_n_epochs=5 \ + --student_ppo_n_minibatches=1 \ + --student_ppo_clip_eps=0.2 \ + --student_ppo_clip_value_loss=True \ + --student_hidden_dim=64 \ + --student_n_hidden_layers=3 \ + --student_n_conv_layers=3 \ + --student_n_conv_filters=32 \ + --student_n_scalar_embeddings=4 \ + --student_scalar_embed_dim=5 \ + --student_agent_kind=mappo \ + --overcooked_height=6 \ + --overcooked_width=9 \ + --overcooked_n_walls=15 \ + --overcooked_replace_wall_pos=True \ + --overcooked_sample_n_walls=True \ + --overcooked_normalize_obs=True \ + --overcooked_max_steps=400 \ + --overcooked_random_reset=False \ + --overcooked_fix_to_single_layout=${layout} \ + --n_shaped_reward_steps=3000000 \ + --test_n_episodes=10 \ + --test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ + --overcooked_test_normalize_obs=True \ + --xpid=8SEED_${seed}_dr-overcookedNonexNonewNone_fs_FIX${layout}_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 +done \ No newline at end of file diff --git a/src/baseline_train__all.sh b/src/baseline_train__all.sh new file mode 100755 index 0000000..83b5a58 --- /dev/null +++ b/src/baseline_train__all.sh @@ -0,0 +1,8 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +./baseline_train__8_seeds.sh $device coord_ring_6_9 +./baseline_train__8_seeds.sh $device counter_circuit_6_9 +./baseline_train__8_seeds.sh $device forced_coord_6_9 +./baseline_train__8_seeds.sh $device cramped_room_6_9 +./baseline_train__8_seeds.sh $device asymm_advantages_6_9 \ No newline at end of file diff --git a/src/baseline_train__holdout_sp.sh b/src/baseline_train__holdout_sp.sh new file mode 100755 index 0000000..af2b080 --- /dev/null +++ b/src/baseline_train__holdout_sp.sh @@ -0,0 +1,71 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +seed=42 + +for layout in "coord_ring_6_9" "forced_coord_6_9" "cramped_room_6_9" "asymm_advantages_6_9" "counter_circuit_6_9"; +do + echo "layout is ${layout}:" + CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ + --wandb_mode=online \ + --wandb_project=overcooked-minimax-jax \ + --wandb_entity=${WANDB_ENTITY} \ + --seed=${seed} \ + --agent_rl_algo=ppo \ + --n_total_updates=1000 \ + --train_runner=dr \ + --n_devices=1 \ + --student_model_name=default_student_actor_cnn \ + --student_critic_model_name=default_student_critic_cnn \ + --env_name=Overcooked \ + --is_multi_agent=True \ + --verbose=False \ + --log_dir=~/logs/minimax \ + --log_interval=10 \ + --from_last_checkpoint=False \ + --checkpoint_interval=25 \ + --archive_interval=25 \ + --archive_init_checkpoint=False \ + --test_interval=50 \ + --n_students=1 \ + --n_parallel=100 \ + --n_eval=1 \ + --n_rollout_steps=400 \ + --lr=3e-4 \ + --lr_anneal_steps=0 \ + --max_grad_norm=0.5 \ + --adam_eps=1e-05 \ + --track_env_metrics=True \ + --discount=0.99 \ + --n_unroll_rollout=10 \ + --render=False \ + --student_gae_lambda=0.95 \ + --student_entropy_coef=0.01 \ + --student_value_loss_coef=0.5 \ + --student_n_unroll_update=5 \ + --student_ppo_n_epochs=5 \ + --student_ppo_n_minibatches=1 \ + --student_ppo_clip_eps=0.2 \ + --student_ppo_clip_value_loss=True \ + --student_hidden_dim=64 \ + --student_n_hidden_layers=3 \ + --student_n_conv_layers=3 \ + --student_n_conv_filters=32 \ + --student_n_scalar_embeddings=4 \ + --student_scalar_embed_dim=5 \ + --student_agent_kind=mappo \ + --overcooked_height=6 \ + --overcooked_width=9 \ + --overcooked_n_walls=15 \ + --overcooked_replace_wall_pos=True \ + --overcooked_sample_n_walls=True \ + --overcooked_normalize_obs=True \ + --overcooked_max_steps=400 \ + --overcooked_random_reset=False \ + --overcooked_fix_to_single_layout=${layout} \ + --n_shaped_reward_steps=3000000 \ + --test_n_episodes=10 \ + --test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ + --overcooked_test_normalize_obs=True \ + --xpid=9SEED_${seed}_dr-overcookedNonexNonewNone_fs_FIX${layout}_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 +done \ No newline at end of file diff --git a/src/config/configs/maze/accel.json b/src/config/configs/maze/accel.json new file mode 100644 index 0000000..bdd7c22 --- /dev/null +++ b/src/config/configs/maze/accel.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/dr.json b/src/config/configs/maze/dr.json new file mode 100644 index 0000000..57c7c98 --- /dev/null +++ b/src/config/configs/maze/dr.json @@ -0,0 +1,59 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/paccel.json b/src/config/configs/maze/paccel.json new file mode 100644 index 0000000..10694da --- /dev/null +++ b/src/config/configs/maze/paccel.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [10], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/paired.json b/src/config/configs/maze/paired.json new file mode 100644 index 0000000..ff0a370 --- /dev/null +++ b/src/config/configs/maze/paired.json @@ -0,0 +1,84 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.995], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.05], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [5], + "teacher_ppo_n_minibatches": [1], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [256], + "teacher_hidden_dim": [32], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [false], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "maze_ued_replace_wall_pos": [true], + "maze_ued_fixed_n_wall_steps": [true], + "maze_ued_first_wall_pos_sets_budget": [false], + "maze_ued_noise_dim": [50], + "maze_ued_n_walls": [60], + "maze_ued_set_agent_dir": [false], + "maze_ued_normalize_obs": [true], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/plr.json b/src/config/configs/maze/plr.json new file mode 100644 index 0000000..229fbf2 --- /dev/null +++ b/src/config/configs/maze/plr.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [5e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/pplr.json b/src/config/configs/maze/pplr.json new file mode 100644 index 0000000..071cab2 --- /dev/null +++ b/src/config/configs/maze/pplr.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/s5_accel.json b/src/config/configs/maze/s5_accel.json new file mode 100644 index 0000000..f2fbf76 --- /dev/null +++ b/src/config/configs/maze/s5_accel.json @@ -0,0 +1,78 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [10], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "test_agent_idxs": ["\"*\""], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/s5_dr.json b/src/config/configs/maze/s5_dr.json new file mode 100644 index 0000000..5f688c5 --- /dev/null +++ b/src/config/configs/maze/s5_dr.json @@ -0,0 +1,63 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/s5_paccel.json b/src/config/configs/maze/s5_paccel.json new file mode 100644 index 0000000..d61f0c2 --- /dev/null +++ b/src/config/configs/maze/s5_paccel.json @@ -0,0 +1,77 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [1e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/s5_paired.json b/src/config/configs/maze/s5_paired.json new file mode 100644 index 0000000..451bd40 --- /dev/null +++ b/src/config/configs/maze/s5_paired.json @@ -0,0 +1,94 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.995], + "teacher_lr": [0.0001], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.001], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [5], + "teacher_ppo_n_minibatches": [1], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["s5"], + "teacher_recurrent_hidden_dim": [256], + "teacher_hidden_dim": [32], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [32], + "teacher_scalar_embed_dim": [10], + "teacher_s5_n_blocks": [2], + "teacher_s5_n_layers": [2], + "teacher_s5_layernorm_pos": ["post"], + "teacher_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [false], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "maze_ued_replace_wall_pos": [true], + "maze_ued_fixed_n_wall_steps": [true], + "maze_ued_first_wall_pos_sets_budget": [false], + "maze_ued_noise_dim": [50], + "maze_ued_n_walls": [60], + "maze_ued_set_agent_dir": [false], + "maze_ued_normalize_obs": [true], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "test_agent_idxs": ["\"*\""], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/s5_plr.json b/src/config/configs/maze/s5_plr.json new file mode 100644 index 0000000..05a9146 --- /dev/null +++ b/src/config/configs/maze/s5_plr.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/maze/s5_pplr.json b/src/config/configs/maze/s5_pplr.json new file mode 100644 index 0000000..cccdce6 --- /dev/null +++ b/src/config/configs/maze/s5_pplr.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline__cnn_asymm_advantages.json b/src/config/configs/overcooked/baseline__cnn_asymm_advantages.json new file mode 100644 index 0000000..6999ad4 --- /dev/null +++ b/src/config/configs/overcooked/baseline__cnn_asymm_advantages.json @@ -0,0 +1,63 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [25], + "archive_interval": [25], + "archive_init_checkpoint": [false], + "test_interval": [50], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [5e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.99], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.95], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_fix_to_single_layout": ["asymm_advantages_6_9"], + "n_shaped_reward_steps": [5000000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline__s5_asymm_advantages.json b/src/config/configs/overcooked/baseline__s5_asymm_advantages.json new file mode 100644 index 0000000..0a7961e --- /dev/null +++ b/src/config/configs/overcooked/baseline__s5_asymm_advantages.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [100], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_fix_to_single_layout": ["asymm_advantages_6_9"], + "n_shaped_reward_steps": [5000000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline__s5_coord_ring.json b/src/config/configs/overcooked/baseline__s5_coord_ring.json new file mode 100644 index 0000000..d1b1eb9 --- /dev/null +++ b/src/config/configs/overcooked/baseline__s5_coord_ring.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [true], + "overcooked_fix_to_single_layout": ["coord_ring_6_9"], + "n_shaped_reward_steps": [5000000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline__s5_counter_circuit.json b/src/config/configs/overcooked/baseline__s5_counter_circuit.json new file mode 100644 index 0000000..ed74216 --- /dev/null +++ b/src/config/configs/overcooked/baseline__s5_counter_circuit.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [true], + "overcooked_fix_to_single_layout": ["counter_circuit_6_9"], + "n_shaped_reward_steps": [5000000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline__s5_cramped_room.json b/src/config/configs/overcooked/baseline__s5_cramped_room.json new file mode 100644 index 0000000..c82f066 --- /dev/null +++ b/src/config/configs/overcooked/baseline__s5_cramped_room.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [true], + "overcooked_fix_to_single_layout": ["cramped_room_6_9"], + "n_shaped_reward_steps": [5000000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline__s5_forced_coord.json b/src/config/configs/overcooked/baseline__s5_forced_coord.json new file mode 100644 index 0000000..2d436ef --- /dev/null +++ b/src/config/configs/overcooked/baseline__s5_forced_coord.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [true], + "overcooked_fix_to_single_layout": ["forced_coord_6_9"], + "n_shaped_reward_steps": [5000000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_dr_lstm.json b/src/config/configs/overcooked/baseline_dr_lstm.json new file mode 100644 index 0000000..9b97792 --- /dev/null +++ b/src/config/configs/overcooked/baseline_dr_lstm.json @@ -0,0 +1,64 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_dr_lstm5x5.json b/src/config/configs/overcooked/baseline_dr_lstm5x5.json new file mode 100644 index 0000000..6a1af77 --- /dev/null +++ b/src/config/configs/overcooked/baseline_dr_lstm5x5.json @@ -0,0 +1,64 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_dr_s5.json b/src/config/configs/overcooked/baseline_dr_s5.json new file mode 100644 index 0000000..91c7864 --- /dev/null +++ b/src/config/configs/overcooked/baseline_dr_s5.json @@ -0,0 +1,68 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_dr_s55x5.json b/src/config/configs/overcooked/baseline_dr_s55x5.json new file mode 100644 index 0000000..7a45833 --- /dev/null +++ b/src/config/configs/overcooked/baseline_dr_s55x5.json @@ -0,0 +1,68 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_dr_softmoe_lstm.json b/src/config/configs/overcooked/baseline_dr_softmoe_lstm.json new file mode 100644 index 0000000..82de0f9 --- /dev/null +++ b/src/config/configs/overcooked/baseline_dr_softmoe_lstm.json @@ -0,0 +1,67 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_dr_softmoe_lstm5x5.json b/src/config/configs/overcooked/baseline_dr_softmoe_lstm5x5.json new file mode 100644 index 0000000..663ab72 --- /dev/null +++ b/src/config/configs/overcooked/baseline_dr_softmoe_lstm5x5.json @@ -0,0 +1,67 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_accel_lstm.json b/src/config/configs/overcooked/baseline_p_accel_lstm.json new file mode 100644 index 0000000..4a3ae9e --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_accel_lstm.json @@ -0,0 +1,78 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_accel_lstm5x5.json b/src/config/configs/overcooked/baseline_p_accel_lstm5x5.json new file mode 100644 index 0000000..cd95ad1 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_accel_lstm5x5.json @@ -0,0 +1,78 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_accel_s5.json b/src/config/configs/overcooked/baseline_p_accel_s5.json new file mode 100644 index 0000000..3913395 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_accel_s5.json @@ -0,0 +1,82 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_accel_s55x5.json b/src/config/configs/overcooked/baseline_p_accel_s55x5.json new file mode 100644 index 0000000..db1b950 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_accel_s55x5.json @@ -0,0 +1,82 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_accel_softmoe_lstm.json b/src/config/configs/overcooked/baseline_p_accel_softmoe_lstm.json new file mode 100644 index 0000000..1bd059a --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_accel_softmoe_lstm.json @@ -0,0 +1,81 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_accel_softmoe_lstm5x5.json b/src/config/configs/overcooked/baseline_p_accel_softmoe_lstm5x5.json new file mode 100644 index 0000000..2bc84ac --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_accel_softmoe_lstm5x5.json @@ -0,0 +1,81 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_plr_lstm.json b/src/config/configs/overcooked/baseline_p_plr_lstm.json new file mode 100644 index 0000000..52a4e45 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_plr_lstm.json @@ -0,0 +1,74 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_plr_lstm5x5.json b/src/config/configs/overcooked/baseline_p_plr_lstm5x5.json new file mode 100644 index 0000000..26a9653 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_plr_lstm5x5.json @@ -0,0 +1,74 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_plr_s5.json b/src/config/configs/overcooked/baseline_p_plr_s5.json new file mode 100644 index 0000000..6fc3dca --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_plr_s5.json @@ -0,0 +1,78 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_plr_s55x5.json b/src/config/configs/overcooked/baseline_p_plr_s55x5.json new file mode 100644 index 0000000..43abe81 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_plr_s55x5.json @@ -0,0 +1,78 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_plr_softmoe_lstm.json b/src/config/configs/overcooked/baseline_p_plr_softmoe_lstm.json new file mode 100644 index 0000000..1803e44 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_plr_softmoe_lstm.json @@ -0,0 +1,77 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_p_plr_softmoe_lstm5x5.json b/src/config/configs/overcooked/baseline_p_plr_softmoe_lstm5x5.json new file mode 100644 index 0000000..29f21e5 --- /dev/null +++ b/src/config/configs/overcooked/baseline_p_plr_softmoe_lstm5x5.json @@ -0,0 +1,77 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_pop_paired_lstm.json b/src/config/configs/overcooked/baseline_pop_paired_lstm.json new file mode 100644 index 0000000..48bd630 --- /dev/null +++ b/src/config/configs/overcooked/baseline_pop_paired_lstm.json @@ -0,0 +1,86 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "verbose": [false], + "is_multi_agent": [true], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.999], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.01], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [8], + "teacher_ppo_n_minibatches": [4], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [64], + "teacher_hidden_dim": [64], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [5], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_ued_replace_wall_pos": [true], + "overcooked_ued_fixed_n_wall_steps": [false], + "overcooked_ued_first_wall_pos_sets_budget": [true], + "overcooked_ued_noise_dim": [50], + "overcooked_ued_n_walls": [15], + "overcooked_ued_normalize_obs": [true], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_pop_paired_lstm5x5.json b/src/config/configs/overcooked/baseline_pop_paired_lstm5x5.json new file mode 100644 index 0000000..72eabdd --- /dev/null +++ b/src/config/configs/overcooked/baseline_pop_paired_lstm5x5.json @@ -0,0 +1,86 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "verbose": [false], + "is_multi_agent": [true], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.999], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.01], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [8], + "teacher_ppo_n_minibatches": [4], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [64], + "teacher_hidden_dim": [64], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [5], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_ued_replace_wall_pos": [true], + "overcooked_ued_fixed_n_wall_steps": [false], + "overcooked_ued_first_wall_pos_sets_budget": [true], + "overcooked_ued_noise_dim": [50], + "overcooked_ued_n_walls": [15], + "overcooked_ued_normalize_obs": [true], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_pop_paired_s5.json b/src/config/configs/overcooked/baseline_pop_paired_s5.json new file mode 100644 index 0000000..eed05b5 --- /dev/null +++ b/src/config/configs/overcooked/baseline_pop_paired_s5.json @@ -0,0 +1,90 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "verbose": [false], + "is_multi_agent": [true], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.999], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.01], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [8], + "teacher_ppo_n_minibatches": [4], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [64], + "teacher_hidden_dim": [64], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [5], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_ued_replace_wall_pos": [true], + "overcooked_ued_fixed_n_wall_steps": [false], + "overcooked_ued_first_wall_pos_sets_budget": [true], + "overcooked_ued_noise_dim": [50], + "overcooked_ued_n_walls": [15], + "overcooked_ued_normalize_obs": [true], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_pop_paired_s55x5.json b/src/config/configs/overcooked/baseline_pop_paired_s55x5.json new file mode 100644 index 0000000..09ef6cf --- /dev/null +++ b/src/config/configs/overcooked/baseline_pop_paired_s55x5.json @@ -0,0 +1,90 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "verbose": [false], + "is_multi_agent": [true], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.999], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.01], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [8], + "teacher_ppo_n_minibatches": [4], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [3], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [64], + "teacher_hidden_dim": [64], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [5], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_ued_replace_wall_pos": [true], + "overcooked_ued_fixed_n_wall_steps": [false], + "overcooked_ued_first_wall_pos_sets_budget": [true], + "overcooked_ued_noise_dim": [50], + "overcooked_ued_n_walls": [15], + "overcooked_ued_normalize_obs": [true], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_pop_paired_softmoe_lstm.json b/src/config/configs/overcooked/baseline_pop_paired_softmoe_lstm.json new file mode 100644 index 0000000..c497a5a --- /dev/null +++ b/src/config/configs/overcooked/baseline_pop_paired_softmoe_lstm.json @@ -0,0 +1,89 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "verbose": [false], + "is_multi_agent": [true], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.999], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.01], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [8], + "teacher_ppo_n_minibatches": [4], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [64], + "teacher_hidden_dim": [64], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [5], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_ued_replace_wall_pos": [true], + "overcooked_ued_fixed_n_wall_steps": [false], + "overcooked_ued_first_wall_pos_sets_budget": [true], + "overcooked_ued_noise_dim": [50], + "overcooked_ued_n_walls": [15], + "overcooked_ued_normalize_obs": [true], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/baseline_pop_paired_softmoe_lstm5x5.json b/src/config/configs/overcooked/baseline_pop_paired_softmoe_lstm5x5.json new file mode 100644 index 0000000..3ec6692 --- /dev/null +++ b/src/config/configs/overcooked/baseline_pop_paired_softmoe_lstm5x5.json @@ -0,0 +1,89 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "verbose": [false], + "is_multi_agent": [true], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-4], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.999], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.01], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.01], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [8], + "teacher_ppo_n_minibatches": [4], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [64], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_is_soft_moe": [true], + "student_soft_moe_num_experts": [4], + "student_soft_moe_num_slots": [32], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [64], + "teacher_hidden_dim": [64], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "overcooked_height": [5], + "overcooked_width": [5], + "overcooked_n_walls": [5], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [false], + "overcooked_ued_replace_wall_pos": [true], + "overcooked_ued_fixed_n_wall_steps": [false], + "overcooked_ued_first_wall_pos_sets_budget": [true], + "overcooked_ued_noise_dim": [50], + "overcooked_ued_n_walls": [15], + "overcooked_ued_normalize_obs": [true], + "n_shaped_reward_updates": [30000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing5_5,Overcooked-ForcedCoord5_5,Overcooked-CrampedRoom5_5" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/paired.json b/src/config/configs/overcooked/paired.json new file mode 100644 index 0000000..6a1bb4f --- /dev/null +++ b/src/config/configs/overcooked/paired.json @@ -0,0 +1,83 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_actor_moe"], + "student_critic_model_name": ["default_student_critic_moe"], + "env_name": ["Overcooked"], + "verbose": [false], + "is_multi_agent": [true], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [100], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.995], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.05], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [5], + "teacher_ppo_n_minibatches": [1], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [256], + "teacher_hidden_dim": [32], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [5], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [true], + "overcooked_ued_replace_wall_pos": [true], + "overcooked_ued_fixed_n_wall_steps": [false], + "overcooked_ued_first_wall_pos_sets_budget": [true], + "overcooked_ued_noise_dim": [50], + "overcooked_ued_n_walls": [15], + "overcooked_ued_normalize_obs": [true], + "test_n_episodes": [10], + "n_shaped_reward_steps": [5000000], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/plr.json b/src/config/configs/overcooked/plr.json new file mode 100644 index 0000000..da68e10 --- /dev/null +++ b/src/config/configs/overcooked/plr.json @@ -0,0 +1,71 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [1000000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_moe"], + "student_critic_model_name": ["default_student_critic_moe"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [100], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [5e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.99], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [false], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [true], + "test_n_episodes": [10], + "n_shaped_reward_steps": [5000000], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/config/configs/overcooked/plr_s5.json b/src/config/configs/overcooked/plr_s5.json new file mode 100644 index 0000000..425944e --- /dev/null +++ b/src/config/configs/overcooked/plr_s5.json @@ -0,0 +1,78 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [100000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_actor_cnn"], + "student_critic_model_name": ["default_student_critic_cnn"], + "env_name": ["Overcooked"], + "is_multi_agent": [true], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [128], + "n_eval": [1], + "n_rollout_steps": [400], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [false], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [8], + "student_ppo_n_minibatches": [4], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [64], + "student_n_hidden_layers": [2], + "student_n_conv_layers": [3], + "student_n_conv_filters": [32], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "student_agent_kind": ["mappo"], + "overcooked_height": [6], + "overcooked_width": [9], + "overcooked_n_walls": [15], + "overcooked_replace_wall_pos": [true], + "overcooked_sample_n_walls": [true], + "overcooked_normalize_obs": [true], + "overcooked_max_steps": [400], + "overcooked_random_reset": [true], + "n_shaped_reward_steps": [5000000], + "test_n_episodes": [10], + "test_env_names": [ + "Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9" + ], + "overcooked_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/eval_all_xpid_against_population_in_all_layouts.sh b/src/eval_all_xpid_against_population_in_all_layouts.sh new file mode 100755 index 0000000..fc7bea3 --- /dev/null +++ b/src/eval_all_xpid_against_population_in_all_layouts.sh @@ -0,0 +1,9 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +# "Overcooked-CoordRing6_9" "Overcooked-ForcedCoord6_9" "Overcooked-CounterCircuit6_9" "Overcooked-AsymmAdvantages6_9" "Overcooked-CrampedRoom6_9" +./eval_xpid_against_population_in_all_layouts.sh $device Overcooked-CoordRing6_9 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXcoord_ring_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 +./eval_xpid_against_population_in_all_layouts.sh $device Overcooked-ForcedCoord6_9 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXforced_coord_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 +./eval_xpid_against_population_in_all_layouts.sh $device Overcooked-CounterCircuit6_9 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXcounter_circuit_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 +./eval_xpid_against_population_in_all_layouts.sh $device Overcooked-AsymmAdvantages6_9 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXasymm_advantages_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 +./eval_xpid_against_population_in_all_layouts.sh $device Overcooked-CrampedRoom6_9 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXcramped_room_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 \ No newline at end of file diff --git a/src/eval_random_against_population.sh b/src/eval_random_against_population.sh new file mode 100755 index 0000000..6a35889 --- /dev/null +++ b/src/eval_random_against_population.sh @@ -0,0 +1,11 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +for env in "Overcooked-CoordRing6_9" "Overcooked-ForcedCoord6_9" "Overcooked-CounterCircuit6_9" "Overcooked-AsymmAdvantages6_9" "Overcooked-CrampedRoom6_9"; +do + CUDA_VISIBLE_DEVICES=${device} LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.evaluate_baseline_against_population \ + --env_names=${env} \ + --population_json="populations/fcp/${env}/population.json" \ + --n_episodes=100 \ + --is_random=True +done \ No newline at end of file diff --git a/src/eval_stay_against_population.sh b/src/eval_stay_against_population.sh new file mode 100755 index 0000000..5beec1e --- /dev/null +++ b/src/eval_stay_against_population.sh @@ -0,0 +1,10 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +for env in "Overcooked-AsymmAdvantages6_9" "Overcooked-CrampedRoom6_9" "Overcooked-CoordRing6_9" "Overcooked-ForcedCoord6_9" "Overcooked-CounterCircuit6_9"; +do + CUDA_VISIBLE_DEVICES=${device} LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.evaluate_baseline_against_population \ + --env_names=${env} \ + --population_json="populations/fcp/${env}/population.json" \ + --n_episodes=100 +done diff --git a/src/eval_xpid.sh b/src/eval_xpid.sh new file mode 100755 index 0000000..6dbdc81 --- /dev/null +++ b/src/eval_xpid.sh @@ -0,0 +1,6 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" +CUDA_VISIBLE_DEVICES=${device} LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.evaluate \ +--xpid=$2 \ +--env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--n_episodes=1000 \ No newline at end of file diff --git a/src/eval_xpid_against_population.sh b/src/eval_xpid_against_population.sh new file mode 100755 index 0000000..26d4510 --- /dev/null +++ b/src/eval_xpid_against_population.sh @@ -0,0 +1,10 @@ +DEFAULTVALUE=4 +# Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 +device="${1:-$DEFAULTVALUE}" +ENV=Overcooked-AsymmAdvantages6_9 +XPID=$2 +CUDA_VISIBLE_DEVICES=${device} LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.evaluate_against_population \ +--xpid=${XPID} \ +--env_names=${ENV} \ +--population_json="populations/fcp/${ENV}/population.json" \ +--n_episodes=100 diff --git a/src/eval_xpid_against_population_in_all_layouts.sh b/src/eval_xpid_against_population_in_all_layouts.sh new file mode 100755 index 0000000..15523ee --- /dev/null +++ b/src/eval_xpid_against_population_in_all_layouts.sh @@ -0,0 +1,14 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" +NAME=$2 +XPID=$3 + +for env in "Overcooked-CoordRing6_9" "Overcooked-ForcedCoord6_9" "Overcooked-CounterCircuit6_9" "Overcooked-AsymmAdvantages6_9" "Overcooked-CrampedRoom6_9"; +do + echo "Evaluating ${NAME} against population in ${env} for xpid ${XPID}" + CUDA_VISIBLE_DEVICES=${device} LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.evaluate_against_population \ + --xpid=${XPID} \ + --env_names=${env} \ + --population_json="populations/fcp/${env}/population.json" \ + --n_episodes=100 +done diff --git a/src/eval_xpid_all_cnn_lstm.sh b/src/eval_xpid_all_cnn_lstm.sh new file mode 100755 index 0000000..1fa1474 --- /dev/null +++ b/src/eval_xpid_all_cnn_lstm.sh @@ -0,0 +1,19 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +./eval_xpid_against_population_in_all_layouts.sh $device DR_CNN-LSTM_SEED1 dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device DR_CNN-LSTM_SEED2 SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device DR_CNN-LSTM_SEED3 SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +./eval_xpid_against_population_in_all_layouts.sh $device PLR_CNN-LSTM_SEED1 plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PLR_CNN-LSTM_SEED2 SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PLR_CNN-LSTM_SEED3 SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_CNN-LSTM_SEED1 paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_CNN-LSTM_SEED2 SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_CNN-LSTM_SEED3 SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 + +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_CNN-LSTM_SEED1 plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_CNN-LSTM_SEED2 SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_CNN-LSTM_SEED3 SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + diff --git a/src/eval_xpid_all_cnn_s5.sh b/src/eval_xpid_all_cnn_s5.sh new file mode 100755 index 0000000..83461ab --- /dev/null +++ b/src/eval_xpid_all_cnn_s5.sh @@ -0,0 +1,19 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +./eval_xpid_against_population_in_all_layouts.sh $device DR_CNN-S5_SEED1 dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +./eval_xpid_against_population_in_all_layouts.sh $device DR_CNN-S5_SEED2 SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +./eval_xpid_against_population_in_all_layouts.sh $device DR_CNN-S5_SEED3 SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +./eval_xpid_against_population_in_all_layouts.sh $device PLR_CNN-S5_SEED1 plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +./eval_xpid_against_population_in_all_layouts.sh $device PLR_CNN-S5_SEED2 SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +./eval_xpid_against_population_in_all_layouts.sh $device PLR_CNN-S5_SEED3 SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_CNN-S5_SEED1 paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_CNN-S5_SEED2 SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_CNN-S5_SEED3 SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 + +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_CNN-S5_SEED1 plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_CNN-S5_SEED2 SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_CNN-S5_SEED3 SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + diff --git a/src/eval_xpid_all_softmoe.sh b/src/eval_xpid_all_softmoe.sh new file mode 100755 index 0000000..d9df838 --- /dev/null +++ b/src/eval_xpid_all_softmoe.sh @@ -0,0 +1,19 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" + +./eval_xpid_against_population_in_all_layouts.sh $device DR_SoftMoE_SEED1 dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +./eval_xpid_against_population_in_all_layouts.sh $device DR_SoftMoE_SEED2 SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +./eval_xpid_against_population_in_all_layouts.sh $device DR_SoftMoE_SEED3 SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +./eval_xpid_against_population_in_all_layouts.sh $device PLR_SoftMoE_SEED1 plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +./eval_xpid_against_population_in_all_layouts.sh $device PLR_SoftMoE_SEED2 SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +./eval_xpid_against_population_in_all_layouts.sh $device PLR_SoftMoE_SEED3 SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_SoftMoE_SEED1 paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_SoftMoE_SEED2 SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +./eval_xpid_against_population_in_all_layouts.sh $device PAIRED_SoftMoE_SEED3 SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 + +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_SoftMoE_SEED1 plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_SoftMoE_SEED2 SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +./eval_xpid_against_population_in_all_layouts.sh $device ACCEL_SoftMoE_SEED3 SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + diff --git a/src/extract_fcp.sh b/src/extract_fcp.sh new file mode 100755 index 0000000..a4bea62 --- /dev/null +++ b/src/extract_fcp.sh @@ -0,0 +1,14 @@ +DEFAULTVALUE=4 +ENV=Overcooked-CrampedRoom5_5 # Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 +device="${1:-$DEFAULTVALUE}" + +seed_max=8 + +for seed in `seq ${seed_max}`; +do + CUDA_VISIBLE_DEVICES=${device} LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.extract_fcp \ + --xpid=8SEED_${seed}_$2 \ + --env_names=${ENV} \ + --n_episodes=100 \ + --trained_seed=${seed} +done \ No newline at end of file diff --git a/src/make_cmd.sh b/src/make_cmd.sh new file mode 100755 index 0000000..60d0d1d --- /dev/null +++ b/src/make_cmd.sh @@ -0,0 +1 @@ +python3 -m minimax.config.make_cmd --config $1/$2 \ No newline at end of file diff --git a/src/minimax/__init__.py b/src/minimax/__init__.py new file mode 100644 index 0000000..ab54186 --- /dev/null +++ b/src/minimax/__init__.py @@ -0,0 +1,9 @@ +from . import envs +from . import agents +from . import models +from . import runners +from . import util +from . import arguments +from . import evaluate +# from . import train +from . import config diff --git a/src/minimax/agents/__init__.py b/src/minimax/agents/__init__.py new file mode 100644 index 0000000..0c4ba41 --- /dev/null +++ b/src/minimax/agents/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .ppo import PPOAgent +from .mappo import MAPPOAgent + + +__all__ = [ + PPOAgent, MAPPOAgent +] \ No newline at end of file diff --git a/src/minimax/agents/agent.py b/src/minimax/agents/agent.py new file mode 100644 index 0000000..d05b15c --- /dev/null +++ b/src/minimax/agents/agent.py @@ -0,0 +1,40 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from abc import ABC + + +class Agent(ABC): + """ + Generic interface for an agent. + """ + @property + def is_recurrent(self): + pass + + @property + def action_info_keys(self): + pass + + def init_params(self, rng, obs, carry=None): + pass + + def init_carry(self, rng, batch_dims): + pass + + def act(self, *args, **kwargs): + pass + + def get_action_dist(self, dist_params, dtype): + pass + + def evaluate(self, *args, **kwargs): + pass + + def update(self, *args, **kwargs): + pass diff --git a/src/minimax/agents/mappo.py b/src/minimax/agents/mappo.py new file mode 100644 index 0000000..78ace4b --- /dev/null +++ b/src/minimax/agents/mappo.py @@ -0,0 +1,449 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from collections import OrderedDict + +import einops +import jax +import jax.numpy as jnp +import optax +from flax.training.train_state import TrainState +from tensorflow_probability.substrates import jax as tfp + +from .agent import Agent + + +class MAPPOAgent(Agent): + def __init__( + self, + actor, + critic, + n_epochs=5, + n_minibatches=1, + value_loss_coef=0.5, + entropy_coef=0.0, + clip_eps=0.2, + clip_value_loss=True, + track_grad_norm=False, + n_unroll_update=1, + n_devices=1): + + self.actor = actor + self.critic = critic + + self.n_epochs = n_epochs + self.n_minibatches = n_minibatches + self.value_loss_coef = value_loss_coef + self.entropy_coef = entropy_coef + self.clip_eps = clip_eps + self.clip_value_loss = clip_value_loss + self.track_grad_norm = track_grad_norm + self.n_unroll_update = n_unroll_update + self.n_devices = n_devices + + self.actor_grad_fn = jax.value_and_grad(self._actor_loss, has_aux=True) + self.critic_grad_fn = jax.value_and_grad( + self._critic_loss, has_aux=True) + + @property + def is_recurrent(self): + # Actor and Critic need to share arch for now. + return self.actor.is_recurrent + + def init_params(self, rng, obs): + """ + Returns initialized parameters and RNN hidden state for a specific + observation shape. + """ + if len(obs) == 2: + obs, shared_obs = obs + else: + raise ValueError("Obs should always be a two tuple for MAPPO!") + + rng, subrng = jax.random.split(rng) + is_recurrent = self.actor.is_recurrent + if is_recurrent: + batch_size = jax.tree_util.tree_leaves(obs)[0].shape[1] + actor_carry = self.actor.initialize_carry( + rng=subrng, batch_dims=(batch_size,)) + critic_carry = self.critic.initialize_carry( + rng=subrng, batch_dims=(batch_size,)) + reset = jnp.zeros((1, batch_size), dtype=jnp.bool_) + + rng, subrng = jax.random.split(rng) + + # Notice that these are different to later observations but they resemble what we need + actor_params = self.actor.init( + subrng, obs[:, :, 0], actor_carry, reset) + critic_params = self.critic.init( + subrng, shared_obs[:, :, 0], critic_carry, reset) + else: + + obs = jnp.concatenate(obs, axis=0) + shared_obs = jnp.concatenate(shared_obs, axis=0) + actor_params = self.actor.init(subrng, obs, None) + critic_params = self.critic.init(subrng, shared_obs, None) + + return (actor_params, critic_params) + + def init_carry(self, rng, batch_dims=1): + actor_carry = self.actor.initialize_carry( + rng=rng, batch_dims=batch_dims) + # This is for evaluation where we throw away the critic + if self.critic is not None: + critic_carry = self.critic.initialize_carry( + rng=rng, batch_dims=batch_dims) + else: + critic_carry = None + return actor_carry, critic_carry + + @partial(jax.jit, static_argnums=(0,)) + def act(self, actor_params, obs, carry=None, reset=None): + logits, carry = self.actor.apply( + actor_params, obs, carry, reset) + + return None, logits, carry + + @partial(jax.jit, static_argnums=(0,)) + def get_value(self, params, shared_obs, carry=None, reset=None): + value, new_carry = self.critic.apply(params, shared_obs, carry, reset) + return value, new_carry + + @partial(jax.jit, static_argnums=(0,)) + def evaluate_action( + self, actor_params, action, obs, actor_carry=None, reset=None + ): + dist_params, actor_carry = self.actor.apply( + actor_params, obs, actor_carry, reset) + dist = self.get_action_dist(dist_params, dtype=action.dtype) + log_prob = dist.log_prob(action) + entropy = dist.entropy() + + return log_prob.squeeze(), \ + entropy.squeeze(), \ + actor_carry + + @partial(jax.jit, static_argnums=(0,)) + def evaluate(self, params, action, obs, carry=None, reset=None): + value, dist_params, carry = self.model.apply(params, obs, carry, reset) + dist = self.get_action_dist(dist_params, dtype=action.dtype) + log_prob = dist.log_prob(action) + entropy = dist.entropy() + + return value.squeeze(), \ + log_prob.squeeze(), \ + entropy.squeeze(), \ + carry + + def get_action_dist(self, dist_params, dtype=jnp.uint8): + return tfp.distributions.Categorical(logits=dist_params, dtype=dtype) + + @partial(jax.jit, static_argnums=(0,)) + def update(self, rng, train_state, batch): + rngs = jax.random.split(rng, self.n_epochs) + + def _scan_epoch(carry, rng): + brng, urng = jax.random.split(rng) + batch, train_state = carry + minibatches = self._get_minibatches(brng, batch) + train_state, stats = \ + self._update_epoch( + urng, train_state, minibatches) + + return (batch, train_state), stats + + (_, train_state), stats = jax.lax.scan( + _scan_epoch, + (batch, train_state), + rngs, + length=len(rngs) + ) + + stats = jax.tree_util.tree_map(lambda x: x.mean(), stats) + train_state = train_state.increment_updates() + + return train_state, stats + + @partial(jax.jit, static_argnums=(0,)) + def get_empty_update_stats(self): + keys = [ + 'total_loss', # actor_loss + critic_loss + 'actor_loss', # loss_actor - entropy_coef*entropy + 'critic_loss', # value_loss_coef*value_loss + 'actor_loss_actor', # Without the entropy term added + 'actor_l2_reg_weight_loss', + 'actor_entropy', + 'actor_mean_target', + 'actor_mean_gae', + 'critic_value_loss', + 'critic_l2_reg_weight_loss', + 'critic_mean_value', + 'critic_mean_target', + 'critic_mean_gae', + 'actor_grad_norm', + 'critic_grad_norm', + ] + + return OrderedDict({k: -jnp.inf for k in keys}) + + @partial(jax.jit, static_argnums=(0,)) + def _update_epoch( + self, + rng, + train_state: TrainState, + minibatches): + + def _update_minibatch(carry, step): + rng, minibatch = step + train_state = carry + + (actor_loss, actor_aux_info), actor_grads = self.actor_grad_fn( + train_state.actor_params, + train_state.actor_apply_fn, + minibatch, + rng, + ) + + (critic_loss, critic_aux_info), critic_grads = self.critic_grad_fn( + train_state.critic_params, + train_state.critic_apply_fn, + minibatch, + rng, + ) + + total_loss = actor_loss + critic_loss + loss_info = (total_loss, actor_loss, critic_loss,) + \ + actor_aux_info + critic_aux_info + loss_info = loss_info + \ + (optax.global_norm(actor_grads), optax.global_norm(critic_grads),) + + if self.n_devices > 1: + loss_info = jax.tree_map( + lambda x: jax.lax.pmean(x, 'device'), loss_info) + actor_grads = jax.tree_map( + lambda x: jax.lax.pmean(x, 'device'), actor_grads) + critic_grads = jax.tree_map( + lambda x: jax.lax.pmean(x, 'device'), critic_grads) + + train_state = train_state.apply_gradients( + actor_grads=actor_grads, + critic_grads=critic_grads) + + stats_def = jax.tree_util.tree_structure(OrderedDict({ + k: 0 for k in [ + 'total_loss', # actor_loss + critic_loss + 'actor_loss', # loss_actor - entropy_coef*entropy + 'critic_loss', # value_loss_coef*value_loss + 'actor_loss_actor', # Without the entropy term added + 'actor_l2_reg_weight_loss', + 'actor_entropy', + 'actor_mean_target', + 'actor_mean_gae', + 'critic_value_loss', + 'critic_l2_reg_weight_loss', + 'critic_mean_value', + 'critic_mean_target', + 'critic_mean_gae', + 'actor_grad_norm', + 'critic_grad_norm', + ]})) + + loss_stats = jax.tree_util.tree_unflatten( + stats_def, jax.tree_util.tree_leaves(loss_info)) + return train_state, loss_stats + + rngs = jax.random.split(rng, self.n_minibatches) + train_state, loss_stats = jax.lax.scan( + _update_minibatch, + train_state, + (rngs, minibatches), + length=self.n_minibatches, + unroll=self.n_unroll_update + ) + + loss_stats = jax.tree_util.tree_map( + lambda x: x.mean(axis=0), loss_stats) + + return train_state, loss_stats + + @partial(jax.jit, static_argnums=(0, 2, 4)) + def _actor_loss( + self, + params, + apply_fn, + batch, + rng=None + ): + """Currently the shape of elements is n_rollout_steps x n_envs x n_env_agents x ...shape. + This is one more than intended for the actor and critic. The extra dimension is for the + env agents. We thus need to merge it into the n_envs dimension. + """ + carry = None + + if self.is_recurrent: + """ + Elements have batch shape of n_rollout_steps x n_envs//n_minibatches + """ + batch = jax.tree_map( + lambda x: einops.rearrange( + x, 't n a ... -> t (n a) ...'), batch + ) + carry = jax.tree_util.tree_map( + lambda x: x[0, :], batch.actor_carry) + obs, _, action, rewards, dones, log_pi_old, value_old, target, gae, carry_old, _ = batch + + if self.is_recurrent: + dones = dones.at[1:, :].set(dones[:-1, :]) + dones = dones.at[0, :].set(False) + _batch = batch._replace(dones=dones) + + # Returns LxB and LxBxH tensors + obs, _, action, _, done, _, _, _, _, _, _ = _batch + log_pi, entropy, carry = apply_fn( + params, action, obs, carry, done) + else: + log_pi, entropy, carry = apply_fn( + params, action, obs, carry_old) + else: + batch = jax.tree_map( + lambda x: einops.rearrange(x, 'n a ... -> (n a) ...'), batch + ) + obs, _, action, rewards, dones, log_pi_old, value_old, target, gae, _, _ = batch + log_pi, entropy, _ = apply_fn(params, action, obs, carry) + + ratio = jnp.exp(log_pi - log_pi_old) + norm_gae = (gae - gae.mean()) / (gae.std() + 1e-5) + loss_actor1 = ratio * norm_gae + loss_actor2 = jnp.clip(ratio, 1.0 - self.clip_eps, + 1.0 + self.clip_eps) * norm_gae + loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean() + + entropy = entropy.mean() + + l2_reg_actor = 0.0 + + actor_loss = loss_actor - self.entropy_coef * entropy + l2_reg_actor + + return actor_loss, ( + loss_actor, + l2_reg_actor, + entropy, + target.mean(), + gae.mean() + ) + + @partial(jax.jit, static_argnums=(0, 2, 4)) + def _critic_loss( + self, + params, + apply_fn, + batch, + rng=None + ): + + carry = None + + if self.is_recurrent: + """ + Elements have batch shape of n_rollout_steps x n_envs//n_minibatches + """ + "Same as in actor loss:" + batch = jax.tree_map( + lambda x: einops.rearrange( + x, 't n a ... -> t (n a) ...'), batch + ) + carry = jax.tree_util.tree_map( + lambda x: x[0, :], batch.critic_carry) + _, obs_shared, action, rewards, dones, log_pi_old, value_old, target, gae, _, carry_old = batch + + if self.is_recurrent: + dones = dones.at[1:, :].set(dones[:-1, :]) + dones = dones.at[0, :].set(False) + _batch = batch._replace(dones=dones) + + # Returns LxB and LxBxH tensors + _, obs_shared, action, _, done, _, _, _, _, _, _ = _batch + value, carry = apply_fn( + params, obs_shared, carry, done) + else: + value, carry = apply_fn( + params, obs_shared, carry_old) + value = value.squeeze(-1) + else: + batch = jax.tree_map( + lambda x: einops.rearrange(x, 'n a ... -> (n a) ...'), batch + ) + obs, obs_shared, action, rewards, dones, log_pi_old, value_old, target, gae, _, _ = batch + value, _ = apply_fn(params, obs_shared, carry) + + if self.clip_value_loss: + value_pred_clipped = value_old + (value - value_old).clip( + -self.clip_eps, self.clip_eps + ) + value_losses = jnp.square(value - target) + value_losses_clipped = jnp.square(value_pred_clipped - target) + value_loss = 0.5 * \ + jnp.maximum(value_losses, value_losses_clipped).mean() + else: + value_pred_clipped = value_old + (value - value_old).clip( + -self.clip_eps, self.clip_eps + ) + value_loss = optax.huber_loss( + value_pred_clipped, target, delta=10.0).mean() + + l2_reg_critic = 0.0 + + critic_loss = self.value_loss_coef*value_loss + l2_reg_critic + + return critic_loss, ( + value_loss, + l2_reg_critic, + value.mean(), + target.mean(), + gae.mean() + ) + + @partial(jax.jit, static_argnums=0) + def _get_minibatches(self, rng, batch): + # get dims based on dones + n_rollout_steps, n_envs = batch.dones.shape[0:2] + if self.is_recurrent: + """ + Reshape elements into a batch shape of + n_minibatches x n_envs//n_minibatches x n_rollout_steps. + """ + assert n_envs % self.n_minibatches == 0, \ + 'Number of environments must be divisible into number of minibatches.' + + n_env_per_minibatch = n_envs//self.n_minibatches + shuffled_idx = jax.random.permutation(rng, jnp.arange(n_envs)) + + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, shuffled_idx, axis=1), batch) + + minibatches = jax.tree_util.tree_map( + lambda x: x.swapaxes(0, 1).reshape( + self.n_minibatches, + n_env_per_minibatch, + n_rollout_steps, + *x.shape[2:] + ).swapaxes(1, 2), shuffled_batch) + else: + n_txns = n_envs*n_rollout_steps + assert n_envs*n_rollout_steps % self.n_minibatches == 0 + + shuffled_idx = jax.random.permutation(rng, jnp.arange(n_txns)) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take( + x.reshape(n_txns, *x.shape[2:]), + shuffled_idx, axis=0), batch) + minibatches = jax.tree_util.tree_map( + lambda x: x.reshape(self.n_minibatches, -1, *x.shape[1:]), shuffled_batch) + + return minibatches diff --git a/src/minimax/agents/ppo.py b/src/minimax/agents/ppo.py new file mode 100644 index 0000000..63b6ceb --- /dev/null +++ b/src/minimax/agents/ppo.py @@ -0,0 +1,304 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from typing import Any, Callable, Tuple +from collections import defaultdict, OrderedDict + +import jax +import jax.numpy as jnp +import optax +from flax.training.train_state import TrainState +from tensorflow_probability.substrates import jax as tfp + +from .agent import Agent + + +class PPOAgent(Agent): + def __init__( + self, + model, + n_epochs=5, + n_minibatches=1, + value_loss_coef=0.5, + entropy_coef=0.0, + clip_eps=0.2, + clip_value_loss=True, + track_grad_norm=False, + n_unroll_update=1, + n_devices=1): + + self.model = model + self.n_epochs = n_epochs + self.n_minibatches = n_minibatches + self.value_loss_coef = value_loss_coef + self.entropy_coef = entropy_coef + self.clip_eps = clip_eps + self.clip_value_loss = clip_value_loss + self.track_grad_norm = track_grad_norm + self.n_unroll_update = n_unroll_update + self.n_devices = n_devices + + self.grad_fn = jax.value_and_grad(self._loss, has_aux=True) + + @property + def is_recurrent(self): + return self.model.is_recurrent + + def init_params(self, rng, obs): + """ + Returns initialized parameters and RNN hidden state for a specific + observation shape. + """ + rng, subrng = jax.random.split(rng) + if self.model.is_recurrent: + batch_size = jax.tree_util.tree_leaves(obs)[0].shape[1] + carry = self.model.initialize_carry( + rng=subrng, batch_dims=(batch_size,)) + reset = jnp.zeros((1, batch_size), dtype=jnp.bool_) + rng, subrng = jax.random.split(rng) + params = self.model.init(subrng, obs, carry, reset) + else: + params = self.model.init(subrng, obs) + + return params + + def init_carry(self, rng, batch_dims=(1,)): + return self.model.initialize_carry(rng=rng, batch_dims=batch_dims) + + @partial(jax.jit, static_argnums=(0,)) + def act(self, params, obs, carry=None, reset=None): + value, logits, carry = self.model.apply(params, obs, carry, reset) + + return value, logits, carry + + @partial(jax.jit, static_argnums=(0,)) + def get_value(self, params, obs, carry=None, reset=None): + value, _, carry = self.model.apply(params, obs, carry, reset) + return value, carry + + @partial(jax.jit, static_argnums=(0,)) + def evaluate(self, params, action, obs, carry=None, reset=None): + value, dist_params, carry = self.model.apply(params, obs, carry, reset) + dist = self.get_action_dist(dist_params, dtype=action.dtype) + log_prob = dist.log_prob(action) + entropy = dist.entropy() + + return value.squeeze(), \ + log_prob.squeeze(), \ + entropy.squeeze(), \ + carry + + def get_action_dist(self, dist_params, dtype=jnp.uint8): + return tfp.distributions.Categorical(logits=dist_params, dtype=dtype) + + @partial(jax.jit, static_argnums=(0,)) + def update(self, rng, train_state, batch): + rngs = jax.random.split(rng, self.n_epochs) + + def _scan_epoch(carry, rng): + brng, urng = jax.random.split(rng) + batch, train_state = carry + minibatches = self._get_minibatches(brng, batch) + train_state, stats = \ + self._update_epoch( + urng, train_state, minibatches) + + return (batch, train_state), stats + + (_, train_state), stats = jax.lax.scan( + _scan_epoch, + (batch, train_state), + rngs, + length=len(rngs) + ) + + stats = jax.tree_util.tree_map(lambda x: x.mean(), stats) + train_state = train_state.increment_updates() + + return train_state, stats + + @partial(jax.jit, static_argnums=(0,)) + def get_empty_update_stats(self): + keys = ['total_loss', + 'actor_loss', + 'value_loss', + 'entropy', + 'mean_value', + 'mean_target', + 'mean_gae', + 'grad_norm'] + + return OrderedDict({k: -jnp.inf for k in keys}) + + @partial(jax.jit, static_argnums=(0,)) + def _update_epoch( + self, + rng, + train_state: TrainState, + minibatches): + + def _update_minibatch(carry, step): + rng, minibatch = step + train_state = carry + + (loss, aux_info), grads = self.grad_fn( + train_state.params, + train_state.apply_fn, + minibatch, + rng, + ) + + loss_info = (loss,) + aux_info + loss_info = loss_info + (optax.global_norm(grads),) + + if self.n_devices > 1: + loss_info = jax.tree_map( + lambda x: jax.lax.pmean(x, 'device'), loss_info) + grads = jax.tree_map( + lambda x: jax.lax.pmean(x, 'device'), grads) + + train_state = train_state.apply_gradients(grads=grads) + + stats_def = jax.tree_util.tree_structure(OrderedDict({ + k: 0 for k in [ + 'total_loss', + 'actor_loss', + 'value_loss', + 'entropy', + 'mean_value', + 'mean_target', + 'mean_gae', + 'grad_norm', + ]})) + + loss_stats = jax.tree_util.tree_unflatten( + stats_def, jax.tree_util.tree_leaves(loss_info)) + + return train_state, loss_stats + + rngs = jax.random.split(rng, self.n_minibatches) + train_state, loss_stats = jax.lax.scan( + _update_minibatch, + train_state, + (rngs, minibatches), + length=self.n_minibatches, + unroll=self.n_unroll_update + ) + + loss_stats = jax.tree_util.tree_map( + lambda x: x.mean(axis=0), loss_stats) + + return train_state, loss_stats + + @partial(jax.jit, static_argnums=(0, 2, 4)) + def _loss( + self, + params, + apply_fn, + batch, + rng=None): + carry = None + + if self.is_recurrent: + """ + Elements have batch shape of n_rollout_steps x n_envs//n_minibatches + """ + carry = jax.tree_util.tree_map(lambda x: x[0, :], batch.carry) + obs, action, rewards, dones, log_pi_old, value_old, target, gae, carry_old = batch + + if self.is_recurrent: + dones = dones.at[1:, :].set(dones[:-1, :]) + dones = dones.at[0, :].set(False) + _batch = batch._replace(dones=dones) + + # Returns LxB and LxBxH tensors + obs, action, _, done, _, _, _, _, _ = _batch + value, log_pi, entropy, carry = apply_fn( + params, action, obs, carry, done) + else: + value, log_pi, entropy, carry = apply_fn( + params, action, obs, carry_old) + else: + obs, action, rewards, dones, log_pi_old, value_old, target, gae, _ = batch + value, log_pi, entropy, _ = apply_fn(params, action, obs, carry) + + if self.clip_value_loss: + value_pred_clipped = value_old + (value - value_old).clip( + -self.clip_eps, self.clip_eps + ) + value_losses = jnp.square(value - target) + value_losses_clipped = jnp.square(value_pred_clipped - target) + value_loss = 0.5 * \ + jnp.maximum(value_losses, value_losses_clipped).mean() + else: + value_loss = optax.huber_loss(value, target).mean() + + if self.model.value_ensemble_size > 1: + gae = gae.at[..., 0].get() + + ratio = jnp.exp(log_pi - log_pi_old) + norm_gae = (gae - gae.mean()) / (gae.std() + 1e-5) + loss_actor1 = ratio * norm_gae + loss_actor2 = jnp.clip(ratio, 1.0 - self.clip_eps, + 1.0 + self.clip_eps) * norm_gae + loss_actor = -jnp.minimum(loss_actor1, loss_actor2).mean() + + entropy = entropy.mean() + + total_loss = ( + loss_actor + self.value_loss_coef*value_loss - self.entropy_coef*entropy + ) + + return total_loss, ( + loss_actor, + value_loss, + entropy, + value.mean(), + target.mean(), + gae.mean() + ) + + @partial(jax.jit, static_argnums=0) + def _get_minibatches(self, rng, batch): + # get dims based on dones + n_rollout_steps, n_envs = batch.dones.shape[0:2] + if self.is_recurrent: + """ + Reshape elements into a batch shape of + n_minibatches x n_envs//n_minibatches x n_rollout_steps. + """ + assert n_envs % self.n_minibatches == 0, \ + 'Number of environments must be divisible into number of minibatches.' + + n_env_per_minibatch = n_envs//self.n_minibatches + shuffled_idx = jax.random.permutation(rng, jnp.arange(n_envs)) + + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take(x, shuffled_idx, axis=1), batch) + + minibatches = jax.tree_util.tree_map( + lambda x: x.swapaxes(0, 1).reshape( + self.n_minibatches, + n_env_per_minibatch, + n_rollout_steps, + *x.shape[2:] + ).swapaxes(1, 2), shuffled_batch) + else: + n_txns = n_envs*n_rollout_steps + assert n_envs*n_rollout_steps % self.n_minibatches == 0 + + shuffled_idx = jax.random.permutation(rng, jnp.arange(n_txns)) + shuffled_batch = jax.tree_util.tree_map( + lambda x: jnp.take( + x.reshape(n_txns, *x.shape[2:]), + shuffled_idx, axis=0), batch) + minibatches = jax.tree_util.tree_map( + lambda x: x.reshape(self.n_minibatches, -1, *x.shape[1:]), shuffled_batch) + + return minibatches diff --git a/src/minimax/arguments.py b/src/minimax/arguments.py new file mode 100644 index 0000000..53e0c81 --- /dev/null +++ b/src/minimax/arguments.py @@ -0,0 +1,1023 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse + +from minimax.util.parsnip import Parsnip +from minimax.util.args import str2bool + + +parser = Parsnip() + + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Training seed.') +parser.add_argument( + '--agent_rl_algo', + type=str, + default='ppo', + choices=['ppo'], + help='Base RL algorithm to use.') +parser.add_argument( + '--n_total_updates', + type=int, + default=30000, + help='Total number of student gradient updates.') +parser.add_argument( + '--train_runner', + type=str, + default='dr', + choices=['dr', 'plr', 'paired'], + help='Algorithm runner.') +parser.add_argument( + '--n_devices', + type=int, + default=1, + help='Number of devices.') +parser.add_argument( + '--is_multi_agent', + type=str2bool, + default=False, + help='Whether multi agent env or not.') +parser.add_argument( + '--n_shaped_reward_steps', + type=int, + default=0, + help='Number of steps to use shaped reward for (linear decreased).') + +parser.add_argument( + '--n_shaped_reward_updates', + type=int, + default=0, + help='Number of steps to use shaped reward for (linear decreased).') + + +# ==== RL runner arguments. +train_runner_subparser = parser.add_subparser( + name='train_runner') +train_runner_subparser.add_argument( + '--n_students', + type=int, + default=1, + help='Number of students in population.') +train_runner_subparser.add_argument( + '--n_parallel', + type=int, + default=1, + help='Number of parallel environments per rollout.') +train_runner_subparser.add_argument( + '--n_eval', + type=int, + default=1, + help='Number of student evaluations per environment.') +train_runner_subparser.add_argument( + '--n_rollout_steps', + type=int, + default=250, + help='Number of rollout steps.') +train_runner_subparser.add_argument( + '--lr', + type=float, + default=1e-4, + help='Initial learning rate.') +train_runner_subparser.add_argument( + '--lr_final', + type=float, + default=None, + nargs="?", + help='Final learning rate.') +train_runner_subparser.add_argument( + '--lr_anneal_steps', + type=int, + default=0, + nargs="?", + help='Number of learning rate annealing steps.') +train_runner_subparser.add_argument( + '--max_grad_norm', + type=float, + default=0.5, + help='max norm of gradients.') +train_runner_subparser.add_argument( + '--adam_eps', + type=float, + default=1e-5, + help='Adam eps.') +train_runner_subparser.add_argument( + '--track_env_metrics', + type=str2bool, + default=False, + help='Track env metrics during training. Can reduce SPS.') +train_runner_subparser.add_argument( + '--discount', + type=float, + default=0.995, + help='Student discount factor for rewards') +train_runner_subparser.add_argument( + '--n_unroll_rollout', + type=int, + default=1, + help='Number of times to unroll rollout scan.') +train_runner_subparser.add_argument( + '--render', + type=str2bool, + default=False, + help='Whether to render.') + +# ------ AC-specific arguments ----- +dr_subparser = parser.add_subparser( + name='dr', + prefix='dr', + dependency={'train_runner': 'dr'}, + dest='train_runner') + +# -------- General UED arguments -------- +parser.add_dependent_argument( + '--ued_score', + type=str, + default='relative_regret', + dependency={'train_runner': ['plr', 'paired']}, + dest='train_runner', + choices=[ + 'relative_regret', + 'mean_relative_regret', + 'population_regret', + 'neg_return', # aka minimax adversarial + 'l1_value_loss', + 'positive_value_loss', + 'max_mc', + 'value_disagreement' + ], + help='UED score of agent.') + +# -------- PAIRED arguments -------- +plr_subparser = parser.add_subparser( + name='plr', + prefix='plr', + dependency={'train_runner': 'plr'}, + dest='train_runner') +plr_subparser.add_argument( + '--replay_prob', + type=float, + default=0.5, + help='PLR replay probability.' +) +plr_subparser.add_argument( + '--buffer_size', + type=int, + default=128, + help='PLR level buffer size.' +) +plr_subparser.add_argument( + '--staleness_coef', + type=float, + default=0.3, + help='Staleness coefficient.' +) +plr_subparser.add_argument( + '--temp', + type=float, + default=1.0, + help='Score distribution temperature.' +) +plr_subparser.add_argument( + '--use_score_ranks', + type=str2bool, + default=True, + help='Use rank-based prioritiziation.' +) +plr_subparser.add_argument( + '--min_fill_ratio', + type=float, + default=0.5, + help='Minimum fill ratio before level replay begins.' +) +plr_subparser.add_argument( + '--use_robust_plr', + type=str2bool, + default=True, + help='Use robust PLR.' +) +plr_subparser.add_argument( + '--use_parallel_eval', + type=str2bool, + default=False, + help='Use rank-based prioritiziation.' +) +plr_subparser.add_argument( + '--force_unique', + type=str2bool, + default=False, + help='Force level buffer members to be unique.' +) +plr_subparser.add_argument( + '--mutation_fn', + type=str, + default=None, + help='Name of mutation function for ACCEL.' +) +plr_subparser.add_argument( + '--n_mutations', + type=int, + default=0, + help='Number of mutations per iteration of ACCEL.' +) +plr_subparser.add_argument( + '--mutation_criterion', + type=str, + default='batch', + help='Criterion for choosing PLR buffer members to mutate.' +) +plr_subparser.add_argument( + '--mutation_subsample_size', + type=int, + default=0, + help='Number of PLR buffer members to mutate into a full batch.' +) + + +# -------- PAIRED arguments -------- +paired_subparser = parser.add_subparser( + name='paired', + prefix='paired', + dependency={'train_runner': 'paired'}, + dest='train_runner') + + +# ==== Student RL arguments. +student_rl_subparser = parser.add_subparser( + name='student_rl', + prefix='student') +student_rl_subparser.add_argument( + '--entropy_coef', + type=float, + default=0.0, + help='entropy term coefficient') +student_rl_subparser.add_argument( + '--value_loss_coef', + type=float, + default=0.5, + help='value loss coefficient (default: 0.5)') +student_rl_subparser.add_argument( + '--n_unroll_update', + type=int, + default=1, + help='Number of times to unroll minibatch scan.') + +# -------- Student PPO arguments. -------- +student_ppo_subparser = parser.add_subparser( + name='student_ppo', + prefix='student_ppo', + dest='student_rl', + dependency={'agent_rl_algo': 'ppo'}) +student_ppo_subparser.add_argument( + '--n_epochs', + type=int, + default=5, + help='Number of PPO epochs.') +student_ppo_subparser.add_argument( + '--n_minibatches', + type=int, + default=1, + help='Number of minibatches per PPO epoch.') +student_ppo_subparser.add_argument( + '--clip_eps', + type=float, + default=0.2, + help='PPO clip parameter') +student_ppo_subparser.add_argument( + '--clip_value_loss', + type=str2bool, + default=True, + help='ppo clip value loss') +parser.add_dependent_argument( + '--gae_lambda', + type=float, + default=0.95, + prefix='student', + dependency={'agent_rl_algo': 'ppo'}, + dest='train_runner', + help='GAE lambda parameter for student.') + + +# ==== Teacher RL arguments. +teacher_rl_subparser = parser.add_subparser( + name='teacher_rl', + prefix='teacher', + dependency={'train_runner': ['paired']}) +teacher_rl_subparser.add_argument( + '--entropy_coef', + type=float, + default=0.0, + help='entropy term coefficient') +teacher_rl_subparser.add_argument( + '--value_loss_coef', + type=float, + default=0.5, + help='value loss coefficient (default: 0.5)') +teacher_rl_subparser.add_argument( + '--n_unroll_update', + type=int, + default=1, + help='Number of times to unroll minibatch scan.') +parser.add_dependent_argument( + '--teacher_discount', + type=float, + default=0.995, + dependency={'train_runner': 'paired'}, + dest='train_runner', + help='discount factor for rewards') +parser.add_dependent_argument( + '--teacher_lr', + type=float, + default=None, + nargs="?", + dependency={'agent_rl_algo': 'ppo', 'train_runner': 'paired'}, + dest='train_runner', + help='Initial learning rate of teacher.') +parser.add_dependent_argument( + '--teacher_lr_final', + type=float, + default=None, + nargs="?", + dependency={'agent_rl_algo': 'ppo', 'train_runner': 'paired'}, + dest='train_runner', + help='Initial learning rate of teacher.') +parser.add_dependent_argument( + '--teacher_lr_anneal_steps', + type=int, + default=0, + nargs="?", + dependency={'agent_rl_algo': 'ppo', 'train_runner': 'paired'}, + dest='train_runner', + help='Initial learning rate of teacher.') + + +# -------- Teacher PPO arguments. -------- +teacher_ppo_subparser = parser.add_subparser( + name='teacher_ppo', + prefix='teacher_ppo', + dest='teacher_rl', + dependency={'agent_rl_algo': 'ppo', 'train_runner': 'paired'}) +teacher_ppo_subparser.add_argument( + '--n_epochs', + type=int, + default=5, + help='Number of PPO epochs.') +teacher_ppo_subparser.add_argument( + '--n_minibatches', + type=int, + default=1, + help='Number of minibatches per PPO epoch.') +teacher_ppo_subparser.add_argument( + '--clip_eps', + type=float, + default=0.2, + help='PPO clip parameter') +teacher_ppo_subparser.add_argument( + '--clip_value_loss', + type=str2bool, + default=True, + help='ppo clip value loss') +parser.add_dependent_argument( + '--teacher_gae_lambda', + type=float, + default=0.95, + dependency={'agent_rl_algo': 'ppo', 'train_runner': 'paired'}, + dest='train_runner', + help='GAE lambda parameter for teacher.') + + +# ==== Student model arguments. +parser.add_argument( + '--student_model_name', + type=str, + default='default_student_cnn', + help='Name of student model architecture.') +parser.add_argument( + '--student_critic_model_name', + type=str, + default=None, + help='Name of student critic model architecture (for MAPPO).') +parser.add_argument( + '--student_agent_kind', + type=str, + default="ppo", + help='PPO vs MAPPO.') + +# Placeholder group for student model args +student_model_parser = parser.add_subparser( + name='student_model', + prefix='student') + +# ---- Maze args for student model ---- +student_maze_model_parser = parser.add_subparser( + name='student_maze_model', + prefix='student', + dest="student_model", + dependency={'env_name': ['Maze*', 'Overcooked*']}) +student_maze_model_parser.add_argument( + '--is_soft_moe', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Whether to use SoftMoE.') +student_maze_model_parser.add_argument( + '--soft_moe_num_experts', + type=int, + default=4, + help='Number of Experts in the SoftMoE layer.') +student_maze_model_parser.add_argument( + '--soft_moe_num_slots', + type=int, + default=32, + help='Number of Slots in the SoftMoE layer.') +student_maze_model_parser.add_argument( + '--recurrent_arch', + type=str, + default=None, + nargs='?', + choices=['gru', 'lstm', 's5'], + help='Student RNN architecture.') +student_maze_model_parser.add_argument( + '--recurrent_hidden_dim', + type=int, + default=0, + help='Student recurrent hidden state size.') +student_maze_model_parser.add_argument( + '--hidden_dim', + type=int, + default=32, + help='Student hidden dimension.') +student_maze_model_parser.add_argument( + '--n_hidden_layers', + type=int, + default=1, + help='Student number of hidden layers in policy/value heads.') +student_maze_model_parser.add_argument( + '--n_conv_layers', + type=int, + default=1, + help='Number of CNN filters for student.') +student_maze_model_parser.add_argument( + '--n_conv_filters', + type=int, + default=16, + help='Number of CNN filters for student.') +student_maze_model_parser.add_argument( + '--n_scalar_embeddings', + type=int, + default=4, + help='Defaults to 4 directional embeddings.') +student_maze_model_parser.add_argument( + '--scalar_embed_dim', + type=int, + default=5, + help='Dimensionality of scalar direction embeddings.') +student_maze_model_parser.add_argument( + '--base_activation', + type=str, + default='relu', + choices=['relu', 'gelu', 'crelu', 'leaky_relu'], + help='Nonlinearity for intermediate layers.') +student_maze_model_parser.add_argument( + '--value_ensemble_size', + type=int, + default=1, + help='Size of value ensemble. Defaults to 1 (no ensemble).') +student_maze_model_parser.add_argument( + '--s5_n_blocks', + type=int, + default=1, + help='Number of S5 blocks.') +student_maze_model_parser.add_argument( + '--s5_n_layers', + type=int, + default=4, + help='Number of S5 encoder layers.') +student_maze_model_parser.add_argument( + '--s5_layernorm_pos', + type=str, + default=None, + help='Layernorm pos in S5.') +student_maze_model_parser.add_argument( + '--s5_activation', + type=str, + default="half_glu1", + choices=["half_glu1", "half_glu2", "full_glu", "gelu"], + help='Number of S5 encoder layers.') + + +# ==== Teacher model arguments. +parser.add_dependent_argument( + '--teacher_model_name', + dependency={'train_runner': ['paired']}, + type=str, + help='Name of teacher model architecture.' +) + +# Placeholder group for teacher model args +teacher_model_parser = parser.add_subparser( + name='teacher_model', + prefix='teacher', + dependency={'train_runner': ['paired']}) + +# ---- Maze args for PAIRED teacher model ---- +teacher_maze_model_parser = parser.add_subparser( + name='teacher_maze_model', + prefix='teacher', + dest="teacher_model", + dependency={'train_runner': 'paired', 'env_name': ['Maze*', 'Overcooked*']}) +teacher_maze_model_parser.add_argument( + '--is_soft_moe', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Whether to use SoftMoE.') +teacher_maze_model_parser.add_argument( + '--soft_moe_num_experts', + type=int, + default=4, + help='Number of Experts in the SoftMoE layer.') +teacher_maze_model_parser.add_argument( + '--soft_moe_num_slots', + type=int, + default=32, + help='Number of Slots in the SoftMoE layer.') +teacher_maze_model_parser.add_argument( + '--recurrent_arch', + type=str, + default=None, + nargs='?', + choices=['gru', 'lstm', 's5'], + help='Teacher RNN architecture.') +teacher_maze_model_parser.add_argument( + '--recurrent_hidden_dim', + type=int, + default=0, + help='Teacher recurrent hidden state size.') +teacher_maze_model_parser.add_argument( + '--hidden_dim', + type=int, + default=32, + help='Teacher hidden dimension.') +teacher_maze_model_parser.add_argument( + '--n_hidden_layers', + type=int, + default=1, + help='Teacher number of hidden layers in policy/value heads.') +teacher_maze_model_parser.add_argument( + '--n_conv_layers', + type=int, + default=1, + help='Number of CNN filters for teacher.') +teacher_maze_model_parser.add_argument( + '--n_conv_filters', + type=int, + default=128, + help='Number of CNN filters for teacher.') +teacher_maze_model_parser.add_argument( + '--scalar_embed_dim', + type=int, + default=10, + help='Dimensionality of time-step embeddings.') +teacher_maze_model_parser.add_argument( + '--base_activation', + type=str, + default='relu', + choices=['relu', 'gelu', 'crelu', 'leaky_relu'], + help='Nonlinearity for intermediate layers.') +teacher_maze_model_parser.add_argument( + '--s5_n_blocks', + type=int, + default=1, + help='Number of S5 blocks.') +teacher_maze_model_parser.add_argument( + '--s5_n_layers', + type=int, + default=4, + help='Number of S5 encoder layers.') +teacher_maze_model_parser.add_argument( + '--s5_layernorm_pos', + type=str, + default=None, + help='Layernorm pos in S5.') +teacher_maze_model_parser.add_argument( + '--s5_activation', + type=str, + default="half_glu1", + choices=["half_glu1", "half_glu2", "full_glu", "gelu"], + help='Number of S5 encoder layers.') + + +# ==== Environment arguments. +parser.add_argument( + '--env_name', + type=str, + default='Maze', + help='Environment to train on') +env_parser = parser.add_subparser( + name='env') + +# -------- UED environment arguments. -------- +ued_env_parser = parser.add_subparser( + name='ued_env') + +# ======== Envoronment-specific subparsers ======== +# -------- Overcooked -------- +env_overcooked_parser = parser.add_subparser( + name='overcooked', + prefix='overcooked', + dependency={'env_name': ['Overcooked', 'Overcooked*']}, + dest='env') +env_overcooked_parser.add_argument( + '--height', + type=int, + default=13, + help='Height of training mazes.') +env_overcooked_parser.add_argument( + '--width', + type=int, + default=13, + help='Width of training mazes.') +env_overcooked_parser.add_argument( + '--random_reset', + type=str2bool, + nargs='?', + const=True, + default=False, + help='If random reset.') +env_overcooked_parser.add_argument( + '--n_walls', + type=int, + default=25, + help='Maximum number of walls in training mazes.') +env_overcooked_parser.add_argument( + '--replace_wall_pos', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Sample wall positions with replacement.') +env_overcooked_parser.add_argument( + '--sample_n_walls', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Uniformly sample n_walls between 0 and n_walls.') +env_overcooked_parser.add_argument( + '--normalize_obs', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Ensure observations are between 0 and 1.') +env_overcooked_parser.add_argument( + '--max_steps', + type=int, + default=400, + help='Maximum number of steps in training episodes.') +env_overcooked_parser.add_argument( + '--fix_to_single_layout', + type=str, + default=None, + help='Fixes Overcooked to a single layout instead of a randome one during reset.') +env_overcooked_parser.add_argument( + '--dense_obs', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Ensure observations are between 0 and 1.') +# -------- Maze -------- +env_maze_parser = parser.add_subparser( + name='maze', + prefix='maze', + dependency={'env_name': ['Maze', 'Maze-MemoryMaze']}, + dest='env') +env_maze_parser.add_argument( + '--height', + type=int, + default=13, + help='Height of training mazes.') +env_maze_parser.add_argument( + '--width', + type=int, + default=13, + help='Width of training mazes.') +env_maze_parser.add_argument( + '--n_walls', + type=int, + default=25, + help='Maximum number of walls in training mazes.') +env_maze_parser.add_argument( + '--replace_wall_pos', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Sample wall positions with replacement.') +env_maze_parser.add_argument( + '--sample_n_walls', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Uniformly sample n_walls between 0 and n_walls.') +# -------- Maze* environments -------- +env_maze_all_parser = parser.add_subparser( + name='maze_all', + prefix='maze', + dependency={'env_name': 'Maze*'}, + dest='env') +env_maze_all_parser.add_argument( + '--see_agent', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Whether the agent sees itself in observations.') +env_maze_all_parser.add_argument( + '--normalize_obs', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Ensure observations are between 0 and 1.') +env_maze_all_parser.add_argument( + '--obs_agent_pos', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Include agent xy pos in observations.') +env_maze_all_parser.add_argument( + '--max_episode_steps', + type=int, + default=250, + help='Maximum number of steps in training episodes.') + +# -------- Maze UED -------- +maze_ued_parser = parser.add_subparser( + name='maze_ued', + prefix='maze_ued', + dependency={'env_name': ['Maze', 'Maze-MemoryMaze'], + 'train_runner': 'paired'}, + dest='ued_env') +maze_ued_parser.add_argument( + '--replace_wall_pos', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Teacher can sample same wall pos multiple times (resulting in variable n_walls).') +maze_ued_parser.add_argument( + '--fixed_n_wall_steps', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Teacher samples exactly n_walls wall positions for each level.') +maze_ued_parser.add_argument( + '--first_wall_pos_sets_budget', + type=str2bool, + nargs='?', + const=True, + default=False, + help='The first wall positional index determines the wall budget.') +maze_ued_parser.add_argument( + '--noise_dim', + type=int, + default=50, + help="Dimension of episodic noise vector injected into the teacher's observation.") +maze_ued_parser.add_argument( + '--n_walls', + type=int, + default=25, + help="Number walls the adversary can place.") +maze_ued_parser.add_argument( + '--set_agent_dir', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Teacher chooses the agent direction on last time step.') +maze_ued_parser.add_argument( + '--normalize_obs', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Normalize teacher observations.') + +# -------- Overcooked UED -------- +overcooked_ued_parser = parser.add_subparser( + name='overcooked_ued', + prefix='overcooked_ued', + dependency={'env_name': ['Overcooked'], 'train_runner': 'paired'}, + dest='ued_env') +overcooked_ued_parser.add_argument( + '--replace_wall_pos', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Teacher can sample same wall pos multiple times (resulting in variable n_walls).') +overcooked_ued_parser.add_argument( + '--fixed_n_wall_steps', + type=str2bool, + nargs='?', + const=True, + default=False, + help='Teacher samples exactly n_walls wall positions for each level.') +overcooked_ued_parser.add_argument( + '--first_wall_pos_sets_budget', + type=str2bool, + nargs='?', + const=True, + default=False, + help='The first wall positional index determines the wall budget.') +overcooked_ued_parser.add_argument( + '--noise_dim', + type=int, + default=50, + help="Dimension of episodic noise vector injected into the teacher's observation.") +overcooked_ued_parser.add_argument( + '--n_walls', + type=int, + default=25, + help="Number walls the adversary can place.") +overcooked_ued_parser.add_argument( + '--normalize_obs', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Normalize teacher observations.') + + +# Logging arguments (All top-level arguments.). +parser.add_argument( + "--verbose", + type=str2bool, + nargs='?', + const=True, + default=False, + help="Print progress to stdout.") +parser.add_argument( + '--xpid', + default='latest', + help='name for the run - prefix to log files') +parser.add_argument( + '--log_dir', + default='~/logs/minimax/', + help='directory to save agent logs') +parser.add_argument( + '--log_interval', + type=int, + default=1, + help='log interval, one log per n updates') +parser.add_argument( + "--from_last_checkpoint", + type=str2bool, + nargs='?', + const=True, + default=False, + help="Begin training from latest checkpoint if available.") +parser.add_argument( + "--checkpoint_interval", + type=int, + default=0, + help="Save model every this many updates.") +parser.add_argument( + "--archive_interval", + type=int, + default=0, + help="Save an archived model every this many updates.") +parser.add_argument( + "--archive_init_checkpoint", + type=str2bool, + nargs='?', + const=True, + default=False, + help="Archive the initial checkpoint.") +parser.add_argument( + '--test_interval', + type=int, + default=10, + help='Evaluate on test envs every this many updates.') + + +# Evaluation args. +eval_parser = parser.add_subparser( + name='eval', + prefix='test') +eval_parser.add_argument( + '--n_episodes', + type=int, + default=10, + help='Number of test episodes per environment') +eval_parser.add_argument( + '--env_names', + type=str, + default=None, + help='Test environments to evaluate on.') +eval_parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="csv of agents to evaluate. '*' indicates all.") +eval_env_parser = parser.add_subparser( + name='eval_env', + prefix='test_env', +) + +# -------- Overcooked eval arguments. -------- +overcooked_eval_parser = parser.add_subparser( + name='overcooked_eval', + prefix='overcooked_test', + dependency={'env_name': 'Overcooked*'}, + dest='eval_env' +) +overcooked_eval_parser.add_argument( + '--normalize_obs', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Ensures observations are between 0 and 1.') + +# -------- Maze eval arguments. -------- +maze_eval_parser = parser.add_subparser( + name='maze_eval', + prefix='maze_test', + dependency={'env_name': 'Maze*'}, + dest='eval_env' +) +maze_eval_parser.add_argument( + "--see_agent", + type=str2bool, + nargs='?', + const=True, + default=True, + help="Maze observations include the agent.") +maze_eval_parser.add_argument( + '--normalize_obs', + type=str2bool, + nargs='?', + const=True, + default=True, + help='Ensures observations are between 0 and 1.') + + +# -------- wandb arguments. -------- +wandb_parser = parser.add_subparser( + name='wandb', + prefix='wandb') +wandb_parser.add_argument( + "--base_url", + type=str, + default="https://api.wandb.ai", + help='wandb base url' +) +# wandb_parser.add_argument( +# "--api_key", +# type=str, +# default=None, +# help='wandb api key' +# ) +wandb_parser.add_argument( + "--mode", + type=str, + default="offline", + help='Online/offline or other mode' +) +wandb_parser.add_argument( + "--entity", + type=str, + default=None, + help='Team name' +) +wandb_parser.add_argument( + "--project", + type=str, + default='paired', + help='wandb project name for logging' +) +wandb_parser.add_argument( + "--group", + type=str, + default=None, + help='wandb group name for logging' +) diff --git a/src/minimax/config/__init__.py b/src/minimax/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/minimax/config/configs/maze/accel.json b/src/minimax/config/configs/maze/accel.json new file mode 100644 index 0000000..bdd7c22 --- /dev/null +++ b/src/minimax/config/configs/maze/accel.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/dr.json b/src/minimax/config/configs/maze/dr.json new file mode 100644 index 0000000..57c7c98 --- /dev/null +++ b/src/minimax/config/configs/maze/dr.json @@ -0,0 +1,59 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/paccel.json b/src/minimax/config/configs/maze/paccel.json new file mode 100644 index 0000000..10694da --- /dev/null +++ b/src/minimax/config/configs/maze/paccel.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [10], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/paired.json b/src/minimax/config/configs/maze/paired.json new file mode 100644 index 0000000..ff0a370 --- /dev/null +++ b/src/minimax/config/configs/maze/paired.json @@ -0,0 +1,84 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.995], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.05], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [5], + "teacher_ppo_n_minibatches": [1], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["lstm"], + "teacher_recurrent_hidden_dim": [256], + "teacher_hidden_dim": [32], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [128], + "teacher_scalar_embed_dim": [10], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [false], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "maze_ued_replace_wall_pos": [true], + "maze_ued_fixed_n_wall_steps": [true], + "maze_ued_first_wall_pos_sets_budget": [false], + "maze_ued_noise_dim": [50], + "maze_ued_n_walls": [60], + "maze_ued_set_agent_dir": [false], + "maze_ued_normalize_obs": [true], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/plr.json b/src/minimax/config/configs/maze/plr.json new file mode 100644 index 0000000..eaeab54 --- /dev/null +++ b/src/minimax/config/configs/maze/plr.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [5e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.1], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/pplr.json b/src/minimax/config/configs/maze/pplr.json new file mode 100644 index 0000000..071cab2 --- /dev/null +++ b/src/minimax/config/configs/maze/pplr.json @@ -0,0 +1,69 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [false], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["lstm"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/s5_accel.json b/src/minimax/config/configs/maze/s5_accel.json new file mode 100644 index 0000000..f2fbf76 --- /dev/null +++ b/src/minimax/config/configs/maze/s5_accel.json @@ -0,0 +1,78 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [10], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "test_agent_idxs": ["\"*\""], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/s5_dr.json b/src/minimax/config/configs/maze/s5_dr.json new file mode 100644 index 0000000..5f688c5 --- /dev/null +++ b/src/minimax/config/configs/maze/s5_dr.json @@ -0,0 +1,63 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["dr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/s5_paccel.json b/src/minimax/config/configs/maze/s5_paccel.json new file mode 100644 index 0000000..d61f0c2 --- /dev/null +++ b/src/minimax/config/configs/maze/s5_paccel.json @@ -0,0 +1,77 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [1e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.8], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "plr_mutation_fn": ["default"], + "plr_n_mutations": [20], + "plr_mutation_criterion": ["batch"], + "plr_mutation_subsample_size": [4], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.0], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [0], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/s5_paired.json b/src/minimax/config/configs/maze/s5_paired.json new file mode 100644 index 0000000..451bd40 --- /dev/null +++ b/src/minimax/config/configs/maze/s5_paired.json @@ -0,0 +1,94 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["paired"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [2], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [0.0001], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["relative_regret"], + "student_gae_lambda": [0.98], + "teacher_discount": [0.995], + "teacher_lr": [0.0001], + "teacher_lr_anneal_steps": [0], + "teacher_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "teacher_entropy_coef": [0.001], + "teacher_value_loss_coef": [0.5], + "teacher_n_unroll_update": [5], + "teacher_ppo_n_epochs": [5], + "teacher_ppo_n_minibatches": [1], + "teacher_ppo_clip_eps": [0.2], + "teacher_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "teacher_model_name": ["default_teacher_cnn"], + "teacher_recurrent_arch": ["s5"], + "teacher_recurrent_hidden_dim": [256], + "teacher_hidden_dim": [32], + "teacher_n_hidden_layers": [1], + "teacher_n_conv_filters": [32], + "teacher_scalar_embed_dim": [10], + "teacher_s5_n_blocks": [2], + "teacher_s5_n_layers": [2], + "teacher_s5_layernorm_pos": ["post"], + "teacher_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [false], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "maze_ued_replace_wall_pos": [true], + "maze_ued_fixed_n_wall_steps": [true], + "maze_ued_first_wall_pos_sets_budget": [false], + "maze_ued_noise_dim": [50], + "maze_ued_n_walls": [60], + "maze_ued_set_agent_dir": [false], + "maze_ued_normalize_obs": [true], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "test_agent_idxs": ["\"*\""], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/s5_plr.json b/src/minimax/config/configs/maze/s5_plr.json new file mode 100644 index 0000000..05a9146 --- /dev/null +++ b/src/minimax/config/configs/maze/s5_plr.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.999], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [false], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["pre"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/configs/maze/s5_pplr.json b/src/minimax/config/configs/maze/s5_pplr.json new file mode 100644 index 0000000..cccdce6 --- /dev/null +++ b/src/minimax/config/configs/maze/s5_pplr.json @@ -0,0 +1,73 @@ +{ + "args": { + "seed": [1], + "agent_rl_algo": ["ppo"], + "n_total_updates": [30000], + "train_runner": ["plr"], + "n_devices": [1], + "student_model_name": ["default_student_cnn"], + "env_name": ["Maze"], + "verbose": [false], + "log_dir": ["~/logs/minimax"], + "log_interval": [10], + "from_last_checkpoint": [true], + "checkpoint_interval": [1000], + "archive_interval": [0], + "archive_init_checkpoint": [false], + "test_interval": [100], + "n_students": [1], + "n_parallel": [32], + "n_eval": [1], + "n_rollout_steps": [256], + "lr": [3e-05], + "lr_anneal_steps": [0], + "max_grad_norm": [0.5], + "adam_eps": [1e-05], + "track_env_metrics": [true], + "discount": [0.995], + "n_unroll_rollout": [10], + "render": [false], + "ued_score": ["max_mc"], + "plr_replay_prob": [0.5], + "plr_buffer_size": [4000], + "plr_staleness_coef": [0.3], + "plr_temp": [0.3], + "plr_use_score_ranks": [true], + "plr_min_fill_ratio": [0.5], + "plr_use_robust_plr": [true], + "plr_use_parallel_eval": [true], + "plr_force_unique": [true], + "student_gae_lambda": [0.98], + "student_entropy_coef": [0.001], + "student_value_loss_coef": [0.5], + "student_n_unroll_update": [5], + "student_ppo_n_epochs": [5], + "student_ppo_n_minibatches": [1], + "student_ppo_clip_eps": [0.2], + "student_ppo_clip_value_loss": [true], + "student_recurrent_arch": ["s5"], + "student_recurrent_hidden_dim": [256], + "student_hidden_dim": [32], + "student_n_hidden_layers": [1], + "student_n_conv_filters": [16], + "student_n_scalar_embeddings": [4], + "student_scalar_embed_dim": [5], + "student_s5_n_blocks": [2], + "student_s5_n_layers": [2], + "student_s5_layernorm_pos": ["post"], + "student_s5_activation": ["half_glu1"], + "maze_height": [13], + "maze_width": [13], + "maze_n_walls": [60], + "maze_replace_wall_pos": [true], + "maze_sample_n_walls": [false], + "maze_see_agent": [false], + "maze_normalize_obs": [true], + "maze_obs_agent_pos": [false], + "maze_max_episode_steps": [250], + "test_n_episodes": [10], + "test_env_names": ["Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze"], + "maze_test_see_agent": [false], + "maze_test_normalize_obs": [true] + } +} \ No newline at end of file diff --git a/src/minimax/config/make_cmd.py b/src/minimax/config/make_cmd.py new file mode 100644 index 0000000..9cca193 --- /dev/null +++ b/src/minimax/config/make_cmd.py @@ -0,0 +1,287 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse +import json +import os +import pathlib + +import numpy as np + +from minimax.util.dotdict import DefaultDotDict +import minimax.config.xpid_maker as xpid_maker + + +def get_wandb_config(): + wandb_config_path = os.path.join(os.path.abspath(os.getcwd()), 'config', 'wandb.json') + if os.path.exists(wandb_config_path): + with open(wandb_config_path, 'r') as config_file: + config = json.load(config_file) + if len(config) == 2: + return { + 'wandb_base_url': config['base_url'], + 'wandb_api_key': config['api_key'], + } + + return {} + + +def generate_train_cmds( + cmd, params, num_trials=1, start_index=0, newlines=False, + xpid_generator=None, xpid_prefix='', + include_wandb_group=False, + count_set=None): + separator = ' \\\n' if newlines else ' ' + + cmds = [] + + if xpid_generator: + params['xpid'] = xpid_generator(cmd, params, xpid_prefix) + if include_wandb_group: + params['wandb_group'] = params['xpid'] + + start_seed = params['seed'] + + for t in range(num_trials): + params['seed'] = start_seed + t + start_index + + _cmd = [f'python -m {cmd}'] + + trial_idx = t + start_index + for k,v in params.items(): + if v is None: + continue + + if k == 'xpid': + v = f'{v}_{trial_idx}' + + assert len(v) < 256, f'{v} exceeds 256 characters!' + + if count_set is not None: + count_set.add(v) + + if v == "*": + v = f'"*"' + + _cmd.append(f'--{k}={v}') + + _cmd = separator.join(_cmd) + + cmds.append(_cmd) + + return cmds + + +def generate_all_params_for_grid(grid, defaults={}): + def update_params_with_choices(prev_params, param, choices): + updated_params = [] + for v in choices: + for p in prev_params: + updated = p.copy() + updated[param] = v + updated_params.append(updated) + + return updated_params + + all_params = [{}] + for param, choices in grid.items(): + all_params = update_params_with_choices(all_params, param, choices) + + full_params = [] + for p in all_params: + d = defaults.copy() + d.update(p) + full_params.append(d) + + return full_params + + +def parse_args(): + parser = argparse.ArgumentParser(description='Make commands') + + parser.add_argument( + '--dir', + type=str, + default='config/configs/', + help='Path to directory with .json configs') + + parser.add_argument( + '--config', '-c', + type=str, + default=None, + help='Name of .json config for hyperparameter search-grid') + + parser.add_argument( + '--n_trials', + type=int, + default=1, + help='Name of .json config for hyperparameter search-grid') + + parser.add_argument( + '--start_index', + default=0, + type=int, + help='Starting trial index of xpid runs') + + parser.add_argument( + '--count', + action='store_true', + help='Print number of generated commands at the end of output.') + + parser.add_argument( + "--checkpoint", + action='store_true', + help='Whether to start from checkpoint' + ) + + parser.add_argument( + "--wandb_base_url", + type=str, + default=None, + help='wandb base url' + ) + parser.add_argument( + "--wandb_api_key", + type=str, + default=None, + help='wandb api key' + ) + parser.add_argument( + '--wandb_project', + type=str, + default=None, + help='wandb project name') + + parser.add_argument( + '--include_wandb_group', + action="store_true", + help='Whether to include wandb group in cmds.') + + return parser.parse_args() + + +def xpid_from_params(cmd, p, prefix=''): + p = DefaultDotDict(p) + + env_info = xpid_maker.get_env_info(p) + runner_info = xpid_maker.get_runner_info(p) + a_algo_info = xpid_maker.get_algo_info(p, role='student') + + a_info = a_algo_info + if cmd != 'finetune': + a_model_info = xpid_maker.get_model_info(p, role='student') + a_info = f"{a_info}_{a_model_info}" + pt_info = '' + else: + pt_agent_info = 'tch' if p.get('ft_teacher') else 'st' + pt_info = f"-{p.get('checkpoint_name', 'checkpoint')}_{pt_agent_info}" + + tch_info = '' + train_runner = p.get('train_runner', 'dr') + if train_runner == 'paired': + tch_algo_info = xpid_maker.get_algo_info(p, role='teacher') + tch_model_info = xpid_maker.get_model_info(p, role='teacher') + tch_info = f"_tch_{tch_algo_info}_{tch_model_info}" + + xpid = f"{train_runner}-{env_info}-{runner_info}-{a_info}{tch_info}{pt_info}" + + return xpid + + +def setup_config_dir(): + config_dir = 'config/configs' + if not os.path.exists(os.path.join(config_dir, 'maze')): + os.makedirs(config_dir, exist_ok=True) + + import shutil + + this_path = os.path.dirname(os.path.abspath(__file__)) + src_path = os.path.join(this_path, 'configs') + + for item in os.listdir(src_path): + src_item = os.path.join(src_path, item) + dst_item = os.path.join(config_dir, item) + + if os.path.isdir(src_item): + shutil.copytree(src_item, dst_item, symlinks=True) + else: + shutil.copy(src_item, dst_item) + + +if __name__ == '__main__': + args = parse_args() + + # Default parameters + params = { + # Not needed. + } + + setup_config_dir() + + json_filename = args.config + if not json_filename.endswith('.json'): + json_filename += '.json' + + grid_path = os.path.join(os.path.expandvars(os.path.expanduser(args.dir)), json_filename) + config = json.load(open(grid_path)) + cmd = config.get('cmd', 'train') + grid = config['args'] + xpid_prefix = '' if 'xpid_prefix' not in config else config['xpid_prefix'] + + if args.checkpoint: + params['checkpoint'] = True + + if 'wandb_project' in grid: + params['wandb_project'] = args.wandb_project + + if args.wandb_base_url: + params['wandb_base_url'] = args.wandb_base_url + if args.wandb_api_key: + params['wandb_api_key'] = args.wandb_api_key + + params.update(get_wandb_config()) + + # Generate all parameter combinations within grid, using defaults for fixed params + all_params = generate_all_params_for_grid(grid, defaults=params) + + unique_xpids = None + if args.count: + unique_xpids = set() + + # Print all commands + if cmd == 'eval': + xpid_generator = None + else: + xpid_generator = xpid_from_params + count = 0 + for p in all_params: + cmds = generate_train_cmds( + cmd, p, + num_trials=args.n_trials, + start_index=args.start_index, + newlines=True, + xpid_generator=xpid_generator, + xpid_prefix=xpid_prefix, + include_wandb_group=args.include_wandb_group, + count_set=unique_xpids) + + for c in cmds: + print(c + '\n') + count += 1 + + if args.count: + print(f'Generated {len(unique_xpids)} unique commands.') + print('Sweep over') + grid_sizes = [] + for k,v in grid.items(): + if len(v) > 1: + grid_sizes.append(len(v)) + print(f'{k}: {len(v)}') + + print(f'Total num settings: {np.prod(grid_sizes)}') + diff --git a/src/minimax/config/xpid_maker.py b/src/minimax/config/xpid_maker.py new file mode 100644 index 0000000..6d140f6 --- /dev/null +++ b/src/minimax/config/xpid_maker.py @@ -0,0 +1,328 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + + +def _get_base_role(role): + return role.removesuffix('_tch').removesuffix('_st') + + +def _get_runner_info(p): + n_students = p.get('n_students', 1) + n_eval = p.get('n_eval', 1) + + n_devices = p.get('n_devices', 1) + device_info = '' + if n_devices > 1: + device_info = f'_d{n_devices}' + + return f"r{n_students}s_{p.n_parallel}p_{n_eval}e_{p.n_rollout_steps}t_ae{p.adam_eps}{device_info}" + + +def _get_runner_info_dr(p): + ac_info = _get_runner_info(p) + + reset_info = "" + if p.ac_reset_env_on_rollout: + reset_info = f"r" + if len(reset_info) > 0: + reset_info = f"_{reset_info}" + + return f"{ac_info}{reset_info}" + + +def _get_ued_runner_info(p): + info = _get_runner_info(p) + + if p.ued_score == 'relative_regret': + ued_score = 'r' + elif p.ued_score == 'mean_relative_regret': + ued_score = 'mr' + elif p.ued_score == 'population_regret': + ued_score = 'p' + elif p.ued_score == 'neg_return': + ued_score = 'nr' + elif p.ued_score == 'l1_value_loss': + ued_score = 'l1v' + elif p.ued_score == 'positive_value_loss': + ued_score = 'pvl' + elif p.ued_score == 'max_mc': + ued_score = 'mm' + elif p.ued_score == 'value_disagreement': + ued_score = 'vd' + else: + raise ValueError(f'Unsupported ued_score {ued_score}') + + info = f"{info}_s{ued_score}" + + return info + + +def _get_plr_runner_info(p): + info = _get_ued_runner_info(p) + + plr_info = f'p{p.plr_replay_prob}b{p.plr_buffer_size}t{p.plr_temp}s{p.plr_staleness_coef}m{p.plr_min_fill_ratio}' + if p.plr_use_score_ranks: + plr_info = f'{plr_info}r' + + if p.plr_mutation_fn: + plr_info = f'{plr_info}_m{p.plr_mutation_fn[:3]}{p.plr_n_mutations}{p.plr_mutation_criterion[:3]}' + if p.plr_mutation_criterion != "batch": + plr_info = f"{plr_info}{p.plr_mutation_subsample_size}" + + plr_prefix = '' + if p.plr_use_robust_plr: + plr_prefix += 'r' + if p.plr_use_parallel_eval: + plr_prefix += 'p' + if p.plr_force_unique: + plr_prefix += 'f' + if len(plr_prefix) > 0: + plr_prefix += '_' + + return f"{plr_prefix}{plr_info}_{info}" + + +def _get_runner_info_paired(p): + return _get_ued_runner_info(p) + + +def _get_env_info_default(p): + return p.env_name.lower().replace('-', '_') + + +def _get_env_info_maze(p): + see_agent = 'na' if not p.maze_see_agent else '' + + placement_info = "" + if p.maze_replace_wall_pos: + placement_info = f'f' + if p.maze_sample_n_walls: + placement_info = f"{placement_info}s" + if len(placement_info) > 0: + placement_info = f"_{placement_info}" + + return f"{p.env_name}{p.maze_height}x{p.maze_width}w{p.maze_n_walls}{see_agent}{placement_info}" + + +def _get_env_info_overcooked(p): + placement_info = "" + if p.overcooked_replace_wall_pos: + placement_info = f'f' + if p.overcooked_sample_n_walls: + placement_info = f"{placement_info}s" + if len(placement_info) > 0: + placement_info = f"_{placement_info}" + + if p.overcooked_fix_to_single_layout: + fix_to_single_layout_info = f"_FIX{p.overcooked_fix_to_single_layout}" + else: + fix_to_single_layout_info = '' + + if p.overcooked_dense_obs: + use_dense = '_DENSE' + else: + use_dense = '_IMAGE' + + return f"{p.env_name}{p.overcooked_height}x{p.overcooked_width}w{p.overcooked_n_walls}{placement_info}{fix_to_single_layout_info}{use_dense}" + + +def _get_env_info_maze_ued(p): + see_agent = 'na' if not p.maze_see_agent else '' + + info = f"_{see_agent}_ld{p.maze_ued_noise_dim}" + + placement_info = "" + if p.maze_ued_fixed_n_wall_steps: + placement_info = f"f" + if p.maze_ued_replace_wall_pos: + placement_info = f"{placement_info}r" + if p.maze_ued_set_agent_dir: + placement_info = f"{placement_info}d" + if p.maze_ued_first_wall_pos_sets_budget: + placement_info = f"{placement_info}b" + if len(placement_info) > 0: + placement_info = f"_{placement_info}" + info = f"{info}{placement_info}" + + return f"{p.env_name}{p.maze_height}x{p.maze_width}w{p.maze_n_walls}{info}" + + +def _get_env_info_overcooked_ued(p): + info = f"_ld{p.overcooked_ued_noise_dim}" + + placement_info = "" + if p.overcooked_ued_fixed_n_wall_steps: + placement_info = f"f" + if p.overcooked_ued_replace_wall_pos: + placement_info = f"{placement_info}r" + if p.overcooked_ued_set_agent_dir: + placement_info = f"{placement_info}d" + if p.overcooked_ued_first_wall_pos_sets_budget: + placement_info = f"{placement_info}b" + if len(placement_info) > 0: + placement_info = f"_{placement_info}" + info = f"{info}{placement_info}" + + return f"{p.env_name}{p.overcooked_height}x{p.overcooked_width}w{p.overcooked_n_walls}{info}" + + +def _get_model_info_maze_default(p, role): + model_info = '' + if f'{role}_recurrent_arch' in p and p[f'{role}_recurrent_arch'] is not None: + model_info = f"{p[f'{role}_recurrent_arch']}_h{p[f'{role}_recurrent_hidden_dim']}" + + if p[f'{role}_recurrent_arch'] == 's5': + model_info = f"{model_info}nb{p.get(f'{role}_s5_n_blocks', 1)}nl{p.get(f'{role}_s5_n_layers',4)}" + + activation = p.get(f'{role}_s5_activation') + if activation == 'half_glu1': + activation = 'hg1' + elif activation == 'half_glu2': + activation = 'hg2' + elif activation == 'full_glu': + activation = 'fg' + elif activation == 'gelu': + activation = 'g' + else: + activation = 'hg1' + model_info = f'a{activation}_{model_info}' + + ln_key = f'{role}_s5_layernorm_pos' + ln_info = None + if ln_key in p: + ln = p[ln_key] + if ln == 'pre': + ln_info = 'pr' + elif ln == 'post': + ln_info = 'po' + + if ln_info is not None: + model_info = f"l{ln_info}_{model_info}" + + if f'{role}_is_soft_moe' in p: + num_experts = p.get(f'{role}_soft_moe_num_experts') + num_slots = p.get(f'{role}_soft_moe_num_slots') + model_info = f'{model_info}__SoftMoE_{num_experts}E_{num_slots}S__' + + model_info = f'_{model_info}' if len(model_info) > 0 else '' + + value_info = '' + value_ensemble_key = f'{role}_value_ensemble_size' + value_ensemble_size = p.get(value_ensemble_key) + if value_ensemble_size and value_ensemble_size > 1: + value_info = f've{value_ensemble_size}' + + base_activation = p.get(f'{role}_base_activation', 'relu')[:2] + + model_info = f"h{p[f'{role}_hidden_dim']}cf{p[f'{role}_n_conv_filters']}fc{p[f'{role}_n_hidden_layers']}se{p[f'{role}_scalar_embed_dim']}ba_{base_activation}{model_info}{value_info}" + + return model_info + + +def _get_algo_info_ppo(p, role): + if role == 'student': + lr = str(p.lr) + if 'lr_final' in p: + lr_final = '' if p.lr_final is None or p.lr_final == p.lr else str( + p.lr_final) + if len(lr_final) > 0: + lr = f"{lr}_{lr_final}" + + if "n_shaped_reward_steps" in p: + lr = f"{lr}_SRS{p.n_shaped_reward_steps}" + elif "n_shaped_reward_updates" in p: + lr = f"{lr}_SRU{p.n_shaped_reward_updates}" + + return f"ppo_lr{lr}g{p.discount}cv{p.student_value_loss_coef}ce{p.student_entropy_coef}e{p.student_ppo_n_epochs}mb{p.student_ppo_n_minibatches}l{p.student_gae_lambda}_pc{p.student_ppo_clip_eps}" + else: + if 'teacher_lr' in p: + teacher_lr = str( + p.lr) if p.teacher_lr is None else str(p.teacher_lr) + else: + teacher_lr = str(p.lr) + + if 'teacher_lr_final' in p: + teacher_lr_final = str( + p.lr_final) if p.teacher_lr_final is None else str(p.teacher_lr_final) + else: + teacher_lr_final = str(p.lr_final) if 'lr_final' in p else '' + + if teacher_lr_final == teacher_lr: + teacher_lr_final = '' + + if len(teacher_lr_final) > 0: + teacher_lr = f"{teacher_lr}_{teacher_lr_final}" + + return f"ppo_lr{teacher_lr}g{p.teacher_discount}cv{p.teacher_value_loss_coef}ce{p.teacher_entropy_coef}e{p.teacher_ppo_n_epochs}mb{p.teacher_ppo_n_minibatches}l{p.teacher_gae_lambda}pc{p.teacher_ppo_clip_eps}" + + +# ============================================================ + +RUNNER_INFO_HANDLERS = { + 'dr': _get_runner_info_dr, + 'plr': _get_plr_runner_info, + 'paired': _get_runner_info_paired, +} + +ENV_INFO_HANDLERS = { + 'maze': _get_env_info_maze, + 'maze_ued': _get_env_info_maze_ued, + 'overcooked': _get_env_info_overcooked, + 'overcooked_ued': _get_env_info_overcooked_ued +} + +MODEL_INFO_HANDLERS = { + 'maze': { + 'default_student_cnn': partial(_get_model_info_maze_default, role='student'), + 'default_teacher_cnn': partial(_get_model_info_maze_default, role='teacher'), + }, + 'overcooked': { + 'default_student_cnn': partial(_get_model_info_maze_default, role='student'), + 'default_student_actor_cnn': partial(_get_model_info_maze_default, role='student'), + 'default_student_critic_cnn': partial(_get_model_info_maze_default, role='student'), + 'default_student_actor_mlp': partial(_get_model_info_maze_default, role='student'), + 'default_student_critic_mlp': partial(_get_model_info_maze_default, role='student'), + 'default_student_actor_moe': partial(_get_model_info_maze_default, role='student'), + 'default_student_critic_moe': partial(_get_model_info_maze_default, role='student'), + 'default_teacher_cnn': partial(_get_model_info_maze_default, role='teacher'), + } +} + +ALGO_INFO_HANDLERS = { + 'ppo': _get_algo_info_ppo +} + + +def get_runner_info(p): + return RUNNER_INFO_HANDLERS[p.get('train_runner', 'dr')](p) + + +def get_env_info(p): + p.env_name = p.env_name.lower() + env_name = p.env_name + if p.train_runner in ['paired',]: + env_name = f'{env_name}_ued' + + return ENV_INFO_HANDLERS.get( + env_name, _get_env_info_default + )(p) + + +def get_model_info(p, role='student'): + model_name = p.get(f'{role}_model_name') + if model_name is None: + model_name = p['student_model_name'] + env_name = p.env_name.lower().split('-')[0] + + return MODEL_INFO_HANDLERS[env_name][model_name](p) + + +def get_algo_info(p, role='student'): + return ALGO_INFO_HANDLERS[p.agent_rl_algo](p, role) diff --git a/src/minimax/count_params.py b/src/minimax/count_params.py new file mode 100644 index 0000000..cde9707 --- /dev/null +++ b/src/minimax/count_params.py @@ -0,0 +1,133 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import json +import re +import fnmatch +import sys +from collections import defaultdict + +import numpy as np +import pandas as pd +import scipy.stats as spstats +import jax +import jax.numpy as jnp +from tqdm import tqdm + +from minimax.util.parsnip import Parsnip +from minimax.util.checkpoint import load_pkl_object, load_config +from minimax.util.loggers import HumanOutputFormat +from minimax.util.rl import AgentPop +import minimax.models as models +import minimax.agents as agents + + +parser = Parsnip() + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Random seed.') +parser.add_argument( + '--log_dir', + type=str, + default='~/logs/minimax', + help='Log directory containing experiment dirs.') +parser.add_argument( + '--xpid', + type=str, + default='latest', + help='Experiment ID dir name for model.') +parser.add_argument( + '--xpid_prefix', + type=str, + default=None, + help='Experiment ID dir name for model.') +parser.add_argument( + '--checkpoint_name', + type=str, + default='checkpoint', + help='Name of checkpoint .pkl.') +parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="Indices of agents to evaluate. '*' indicates all.") + + +if __name__ == '__main__': + """ + Usage: + python -m eval \ + --xpid= \ + --env_names="Maze-SixteenRooms" \ + --n_episodes=100 \ + --agent_idxs=0 + """ + args = parser.parse_args() + + log_dir_path = os.path.expandvars(os.path.expanduser(args.log_dir)) + + xpids = [] + if args.xpid_prefix is not None: + # Get all matching xpid directories + all_xpids = fnmatch.filter(os.listdir( + log_dir_path), f"{args.xpid_prefix}*") + filter_re = re.compile('.*_[0-9]*$') + xpids = [x for x in all_xpids if filter_re.match(x)] + else: + xpids = [args.xpid] + + pbar = tqdm(total=len(xpids)) + + all_eval_stats = defaultdict(list) + for xpid in xpids: + xpid_dir_path = os.path.join(log_dir_path, xpid) + checkpoint_path = os.path.join( + xpid_dir_path, f'{args.checkpoint_name}.pkl') + meta_path = os.path.join(xpid_dir_path, f'meta.json') + + # Load checkpoint info + if not os.path.exists(meta_path): + print(f'Configuration at {meta_path} does not exist. Skipping...') + continue + + if not os.path.exists(checkpoint_path): + print( + f'Checkpoint path {checkpoint_path} does not exist. Skipping...') + continue + + xp_args = load_config(meta_path) + + agent_idxs = args.agent_idxs + if agent_idxs == '*': + agent_idxs = np.arange(xp_args.train_runner_args.n_students) + else: + agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + assert np.max(agent_idxs) <= xp_args.train_runner_args.n_students, \ + 'Agent index is out of bounds.' + + runner_state = load_pkl_object(checkpoint_path) + if "params" in runner_state[1].keys(): + params = runner_state[1]['params'] + elif "actor_params" in runner_state[1].keys(): + params = runner_state[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params + ) + + param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) + print(f"Model has {param_count} parameters.") diff --git a/src/minimax/envs/__init__.py b/src/minimax/envs/__init__.py new file mode 100644 index 0000000..6bf1d79 --- /dev/null +++ b/src/minimax/envs/__init__.py @@ -0,0 +1,15 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .maze import Maze, UEDMaze +from .batch_env import BatchEnv +from .batch_env_ued import BatchUEDEnv + +from .overcooked_proc import Overcooked, UEDOvercooked + +from .registration import make, get_comparator, get_mutator \ No newline at end of file diff --git a/src/minimax/envs/batch_env.py b/src/minimax/envs/batch_env.py new file mode 100644 index 0000000..5d342c9 --- /dev/null +++ b/src/minimax/envs/batch_env.py @@ -0,0 +1,74 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import jax.numpy as jnp + +import minimax.envs as envs + + +class BatchEnv: + def __init__( + self, + env_name, + n_parallel, + n_eval, + env_kwargs, + wrappers=['monitor_return']): + self.env_name = env_name + self.env, self.env_params = envs.make( + env_name, + env_kwargs=env_kwargs, + wrappers=wrappers, + ) + self.n_parallel = n_parallel + self.n_eval = n_eval + + self.sub_batch_size = n_parallel*n_eval + + self.step = jax.vmap(self._step, in_axes=0) + self.get_env_metrics = jax.vmap(self._get_env_metrics, in_axes=0) + self.set_state = jax.vmap(self._set_state, in_axes=0) + + @partial(jax.jit, static_argnums=(0, 2, 3)) + def reset(self, rng, n_parallel=None, n_eval=None): + return jax.vmap(self._reset, in_axes=(0, None, None))(rng, n_parallel, n_eval) + + def _reset(self, rng, n_parallel=None, n_eval=None): + # Create n_parallel envs, repeated n_eval times + if n_parallel is None: + n_parallel = self.n_parallel + + if n_eval is None: + n_eval = self.n_eval + + brngs = jnp.repeat(jax.random.split(rng, n_parallel), n_eval, axis=0) + + obs, state, extra = jax.vmap( + self.env.reset, in_axes=(0,))(brngs) + + return obs, state, extra + + @partial(jax.jit, static_argnums=0) + def _step(self, rng, state, action, extra): + brngs = jax.random.split(rng, self.sub_batch_size) + return jax.vmap(self.env.step, in_axes=(0, 0, 0, 0, 0))( + brngs, state, action, None, extra) + + @partial(jax.jit, static_argnums=(0,)) + def _get_env_metrics(self, state): + return jax.vmap(self.env.get_env_metrics, in_axes=(0,))(state) + + @partial(jax.jit, static_argnums=(0,)) + def _set_state(self, state): + # Need to repeat the state + state = jax.tree_map(lambda x: x.repeat(self.n_eval, axis=0), state) + + return jax.vmap(self.env.set_state)(state) diff --git a/src/minimax/envs/batch_env_ued.py b/src/minimax/envs/batch_env_ued.py new file mode 100644 index 0000000..497b0c7 --- /dev/null +++ b/src/minimax/envs/batch_env_ued.py @@ -0,0 +1,134 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import jax.numpy as jnp + +import minimax.envs as envs + + +class BatchUEDEnv: + """ + Wraps and batches a UEDEnvironment in + its private methods as follows: + + For student MDP: + Manages a batch of n_parallel x n_eval envs + + For teacher MDP: + Manages a batch of n_parallel envs. + + The public interface vmaps the private methods over + an additional agent population dimension. + """ + def __init__( + self, + env_name, + n_parallel, + n_eval, + env_kwargs, + ued_env_kwargs, + wrappers=['monitor_return'], + ued_wrappers=None): + self.wrappers = wrappers + self.env, self.env_params, self.ued_params = \ + envs.make( + env_name, + env_kwargs=env_kwargs, + ued_env_kwargs=ued_env_kwargs, + wrappers=wrappers, + ued_wrappers=ued_wrappers + ) + + self.n_parallel = n_parallel + self.n_eval = n_eval + self.sub_batch_size = n_parallel*n_eval + + self.reset_student = jax.vmap(self._reset_student, in_axes=(0,0,None)) + self.step_teacher = jax.vmap(self._step_teacher, in_axes=0) + self.step_student = jax.vmap(self._step_student, in_axes=0) + + self.set_env_instance = jax.vmap(self._set_env_instance, in_axes=0) + self.get_env_metrics = jax.vmap(self._get_env_metrics, in_axes=0) + + partial(jax.jit, static_argnums=(2,)) + def reset(self, rng, sub_batch_size=None): + if sub_batch_size is None: + sub_batch_size = self.sub_batch_size + + return jax.vmap(self._reset, in_axes=(0,None))(rng, sub_batch_size) + + def _reset(self, rng, sub_batch_size): + brngs = jax.random.split(rng, sub_batch_size) + return jax.vmap(self.env.reset)(brngs) + + partial(jax.jit, static_argnums=(2,)) + def reset_teacher(self, rng, n_parallel=None): + if n_parallel is None: + n_parallel = self.n_parallel + + return jax.vmap(self._reset_teacher, in_axes=(0,None))(rng, n_parallel) + + def _reset_teacher(self, rng, n_parallel): + """ + Reset n_parallel envs + """ + brngs = jax.random.split(rng, n_parallel) + return jax.vmap(self.env.reset_teacher)(brngs) + + def _step_teacher(self, rng, ued_state, action, extra=None): + """ + Step n_parallel envs + """ + brngs = jax.random.split(rng, self.n_parallel) + step_args = (brngs, ued_state, action) + if extra is not None: + step_args += (extra,) + + return jax.vmap(self.env.step_teacher)(*step_args) + + def _reset_student(self, rng, ued_state, n_students): + """ + Reset the student MDP based on the state of the teacher MDP. + """ + brngs = jax.random.split(rng, self.n_parallel) + obs, state, extra = \ + jax.vmap(self.env.reset_student)(brngs, ued_state) + + obs = jax.tree_util.tree_map( + lambda x: jnp.repeat( + jnp.expand_dims(jnp.repeat(x, self.n_eval, 0), 0), n_students, 0), obs) + + state = jax.tree_util.tree_map( + lambda x: jnp.repeat( + jnp.expand_dims(jnp.repeat(x, self.n_eval, 0), 0), n_students, 0), state) + + extra = jax.tree_util.tree_map( + lambda x: jnp.repeat( + jnp.expand_dims(jnp.repeat(x, self.n_eval, 0), 0), n_students, 0), extra) + + return obs, state, extra + + def _step_student(self, rng, state, action, reset_state, extra=None): + """ + Step the student MDP. + """ + brngs = jax.random.split(rng, self.sub_batch_size) + return jax.vmap(self.env.step)(brngs, state, action, reset_state, extra) + + def _set_env_instance(self, instance): + """ + Reset the student MDP to a particular configuration, + captured by state argument. Used for PLR. + """ + return jax.vmap(self.env.set_env_instance)(instance) + + def _get_env_metrics(self, state): + return jax.vmap(self.env.get_env_metrics)(state) diff --git a/src/minimax/envs/environment.py b/src/minimax/envs/environment.py new file mode 100644 index 0000000..2719d60 --- /dev/null +++ b/src/minimax/envs/environment.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This file extends the Environment class from +https://github.com/RobertTLange/gymnax/ + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import jax +import chex +from typing import Tuple, Union, Optional +from functools import partial +from flax import struct + + +@struct.dataclass +class EnvState: + time: int + + +@struct.dataclass +class EnvParams: + max_episode_steps: int + + +class Environment(object): + """Jittable abstract base class for all basic environments.""" + + def __init__(self): + self.eval_solved_rate = self.get_eval_solved_rate_fn() + + @property + def default_params(self) -> EnvParams: + return EnvParams() + + @staticmethod + def align_kwargs(kwargs, other_kwargs): + """ + Return kwargs that are consistent with other_kwargs, + e.g. in the case of the student env, other_kwargs may be + those for the paired teacher env, and in the case of the + teacher env, the paired student env. + """ + raise NotImplementedError + + @partial(jax.jit, static_argnums=(0, 4)) + def step( + self, + key: chex.PRNGKey, + state: EnvState, + action: Union[int, float], + reset_on_done: bool = True, + reset_state: Optional[chex.ArrayTree] = None, + ) -> Tuple[chex.ArrayTree, EnvState, float, bool]: + """Performs step transitions in the environment.""" + # Use default env parameters if no others specified + if hasattr(self, 'params'): + params = self.params + else: + params = self.default_params + + key, key_reset = jax.random.split(key) + obs_st, state_st, reward, done, info = self.step_env( + key, state, action + ) + + if reset_on_done: + if reset_state is not None: + state_re = reset_state + obs_re = self.get_obs(reset_state) + else: + if hasattr(params, 'singleton_seed') \ + and params.singleton_seed >= 0: + key_reset = jax.random.PRNGKey(params.singleton_seed) + + obs_re, state_re = self.reset_env(key_reset) + + # Auto-reset environment based on termination + if type(done) == dict: + # Multi Agent setting + done = done["__all__"] + + state = jax.tree_map( + lambda x, y: jax.lax.select(done, x, y), state_re, state_st + ) + obs = jax.tree_map( + lambda x, y: jax.lax.select(done, x, y), obs_re, obs_st + ) + else: + obs, state = obs_st, state_st + + return obs, state, reward, done, info + + @partial(jax.jit, static_argnums=(0,)) + def reset( + self, + key: chex.PRNGKey, + ) -> Tuple[chex.ArrayTree, EnvState]: + """Performs resetting of environment.""" + # Use default env parameters if no others specified + if hasattr(self, 'params'): + params = self.params + else: + params = self.default_params + + if hasattr(params, 'singleton_seed') \ + and params.singleton_seed >= 0: + key = jax.random.PRNGKey(params.singleton_seed) + obs, state = self.reset_env(key) + return obs, state + + def step_env( + self, + key: chex.PRNGKey, + state: EnvState, + action: Union[int, float], + ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]: + """Environment-specific step transition.""" + raise NotImplementedError + + def reset_env( + self, key: chex.PRNGKey + ) -> Tuple[chex.ArrayTree, EnvState]: + """Environment-specific reset.""" + raise NotImplementedError + + def set_state( + self, + state: EnvState + ) -> Tuple[chex.ArrayTree, EnvState]: + """ + Implemented for basic envs. + """ + return self.get_obs(state), state + + def set_env_instance( + self, + encoding: chex.ArrayTree + ) -> Tuple[chex.ArrayTree, EnvState]: + """ + Implemented for basic envs. + """ + raise NotImplementedError + + def get_env_instance( + self, + key: chex.PRNGKey, + state: EnvState + ) -> chex.ArrayTree: + """ + Implemented for UED envs. + """ + raise NotImplementedError + + def get_obs(self, state: EnvState) -> chex.ArrayTree: + """Applies observation function to state.""" + raise NotImplementedError + + def is_terminal(self, state: EnvState) -> bool: + """Check whether state is terminal.""" + raise NotImplementedError + + def get_eval_solved_rate_fn(self): + return None + + @property + def name(self) -> str: + """Environment name.""" + return type(self).__name__ + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + raise NotImplementedError + + def action_space(self): + """Action space of the environment.""" + raise NotImplementedError + + def observation_space(self): + """Observation space of the environment.""" + raise NotImplementedError + + def state_space(self): + """State space of the environment.""" + raise NotImplementedError + + def max_episode_steps(self): + """Maximum number of time steps in environment.""" + raise NotImplementedError + + def get_env_metrics(self, state: EnvState): + """Environment-specific metrics, e.g. number of walls.""" + raise NotImplementedError diff --git a/src/minimax/envs/environment_ued.py b/src/minimax/envs/environment_ued.py new file mode 100644 index 0000000..bab01ff --- /dev/null +++ b/src/minimax/envs/environment_ued.py @@ -0,0 +1,142 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import jax +import chex +from typing import Tuple, Union, Optional +from functools import partial +from flax import struct + +import jax.numpy as jnp + +from .environment import EnvParams, EnvState, Environment + + +@struct.dataclass +class UEDEnvState: + encoding: chex.Array + time: int + terminal: bool + + +class UEDEnvironment: + """ + Wraps two Environment instances, one being the basic environment, + and the other, its UED counterpart. + + The interface extends the student environment interace. + """ + + def __init__(self, env, ued_env): + self.env = env + self.ued_env = ued_env + + # Default reset and step centers on student + self.reset = self.reset_random + self.step = self.env.step + + def reset_random( + self, + rng: chex.PRNGKey, + ) -> Tuple[chex.ArrayTree, EnvState]: + return self.env.reset(rng) + + def get_monitored_metrics(self): + return self.env.get_monitored_metrics() + self.ued_env.get_monitored_metrics() + + def reset_teacher( + self, + rng: chex.PRNGKey, + ) -> Tuple[chex.ArrayTree, EnvState]: + return self.ued_env.reset(rng) + + def step_teacher( + self, + rng: chex.PRNGKey, + ued_state: EnvState, + action: Union[int, float], + ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]: + return self.ued_env.step( + rng, ued_state, action, reset_on_done=False) + + def reset_student( + self, + rng: chex.PRNGKey, + ued_state: EnvState, + ) -> Tuple[chex.ArrayTree, EnvState]: + """ + Reset the student based on + """ + # ued_state_ = UEDEnvState( + # encoding=jnp.array([17, 6, 3, 23, 4, 21, 2, 3, 16, 12, 9], dtype=jnp.uint32), time=jnp.array(11, dtype=jnp.uint32), terminal=jnp.array(True)) + encoding = self.ued_env.get_env_instance(rng, ued_state) + env = self.env.set_env_instance(encoding) + return env + + def step_student( + self, + rng: chex.PRNGKey, + state: EnvState, + action: Union[int, float], + reset_state: Optional[chex.ArrayTree] = None + ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]: + return self.env.step( + rng, + state, + action, + reset_state=reset_state) + + def set_env_instance(self, encoding: chex.ArrayTree): + return self.env.set_env_instance(encoding) + + # Spaces interface + def action_space(self): + """Action space of the environment.""" + return self.env.action_space() + + def observation_space(self): + """Observation space of the environment.""" + return self.env.observation_space() + + def state_space(self): + """Observation space of the environment.""" + return self.env.state_space() + + def max_episode_steps(self): + """Action space of the environment.""" + return self.env.max_episode_steps() + + def ued_action_space(self): + """Action space of the environment.""" + return self.ued_env.action_space() + + def ued_observation_space(self): + """Observation space of the environment.""" + return self.ued_env.observation_space() + + def ued_state_space(self): + """Observation space of the environment.""" + return self.ued_env.state_space() + + def ued_max_episode_steps(self): + """Action space of the environment.""" + return self.ued_env.max_episode_steps() + + def get_env_metrics(self, state: EnvState): + """Environment-specific metrics, e.g. number of walls.""" + return self.env.get_env_metrics(state) + + @property + def agents(self) -> str: + """Environment name.""" + return self.env.agents + + @property + def name(self) -> str: + """Environment name.""" + return self.env.name diff --git a/src/minimax/envs/interactive/__init__.py b/src/minimax/envs/interactive/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/minimax/envs/interactive/manual_ctrl_maze.py b/src/minimax/envs/interactive/manual_ctrl_maze.py new file mode 100644 index 0000000..d20554b --- /dev/null +++ b/src/minimax/envs/interactive/manual_ctrl_maze.py @@ -0,0 +1,219 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import time +import argparse +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np + +import minimax.envs as envs +from minimax.envs.maze.maze import Actions +from minimax.envs.viz.grid_viz import GridVisualizer + + +def redraw(state, obs, extras): + if extras['is_ued_maze']: + env_instance = extras['env'].get_env_instance(None, state) + maze_map = extras['render_env'].set_env_instance(env_instance)[1].maze_map + extras['viz'].render(extras['params'], state, highlight=False, maze_map=maze_map) + else: + extras['viz'].render(extras['params'], state) + if extras['obs_viz'] is not None: + extras['obs_viz'].render_grid(np.asarray(obs['image']), k_rot90=3, agent_dir_idx=3) + +def reset(key, env, extras): + key, subkey = jax.random.split(extras['rng']) + obs, state = extras['jit_reset'](subkey) + + extras['rng'] = key + extras['obs'] = obs + extras['state'] = state + extras['n'] += 1 + + if not extras['is_ued_maze']: + metrics = env.get_env_metrics(state) + print(metrics) + extras['n_walls_total'] += metrics['n_walls'] + + if not extras['is_ued_maze']: + print(f"mean walls: {extras['n_walls_total']/extras['n']}", flush=True) + + redraw(state, obs, extras) + +def step(env, action, extras): + key, subkey = jax.random.split(extras['rng']) + obs, state, reward, done, info = env.step_env(subkey, extras['state'], action) + extras['obs'] = obs + extras['state'] = state + # print(f"reward={reward}, agent_dir={obs['agent_dir']}") + print(f"reward={reward}") + + if done or action == Actions.done: + key, subkey = jax.random.split(subkey) + reset(subkey, env, extras) + else: + redraw(state, obs, extras) + + extras['rng'] = key + + +def key_handler(env, extras, event): + print('pressed', event.key) + + if event.key == 'escape': + window.close() + return + + if event.key == 'backspace': + extras['jit_reset']((env, extras)) + return + + if event.key == 'left': + step(env, Actions.left, extras) + return + if event.key == 'right': + step(env, Actions.right, extras) + return + if event.key == 'up': + step(env, Actions.forward, extras) + return + + # Spacebar + if event.key == ' ': + step(env, Actions.toggle, extras) + return + if event.key == 'a': + step(env, Actions.pickup, extras) + return + if event.key == 'd': + step(env, Actions.drop, extras) + return + + if event.key == 'enter': + step(env, Actions.done, extras) + return + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--env", + type=str, + help="Environment name", + default="Maze" + ) + parser.add_argument( + "--seed", + type=int, + help="random seed to generate the environment with", + default=90 + ) + parser.add_argument( + '--render_agent_view', + default=False, + help="draw the agent sees (partially observable view)", + action='store_true' + ) + parser.add_argument( + '--height', + default=13, + type=int, + help="height", + ) + parser.add_argument( + '--width', + default=13, + type=int, + help="width", + ) + parser.add_argument( + '--n_walls', + default=10, + type=int, + help="Number of walls", + ) + parser.add_argument( + '--agent_view_size', + default=5, + type=int, + help="Number of walls", + ) + parser.add_argument( + "--screenshot_path", + type=str, + default=None, + help="maze.png", + ) + args = parser.parse_args() + + kwargs = dict( + height=args.height, + width=args.width, + n_walls=args.n_walls, + agent_view_size=args.agent_view_size, + see_through_walls=True, + see_agent=True, + normalize_obs=False, + sample_n_walls=False, + replace_wall_pos=False, + max_episode_steps=250, + ) + kwargs = {} + env, params = envs.make(args.env, kwargs) + params = env.params + + is_ued_maze = False + render_env = None + if args.env.startswith('UEDMaze'): + is_ued_maze = True + render_env, _ = envs.make('Maze', kwargs) + + viz = GridVisualizer() + obs_viz = None + if args.render_agent_view: + obs_viz = GridVisualizer() + + with jax.disable_jit(False): + jit_reset = jax.jit(env.reset_env, static_argnums=(1,)) + key = jax.random.PRNGKey(args.seed) + key, subkey = jax.random.split(key) + o0, s0 = jit_reset(subkey) + if is_ued_maze: + maze_map = render_env.set_env_instance(env.get_env_instance(None, s0))[1].maze_map + viz.render(params, s0, highlight=False, maze_map=maze_map) + else: + viz.render(params, s0) + if obs_viz is not None: + obs_viz.render_grid(np.asarray(o0['image']), k_rot90=3, agent_dir_idx=3) + + key, subkey = jax.random.split(key) + extras = { + 'rng': subkey, + 'state': s0, + 'obs': o0, + 'params':params, + 'viz': viz, + 'obs_viz': obs_viz, + 'jit_reset': jit_reset, + 'n_walls_total': 0, + 'n': 0, + 'env': env, + 'render_env': render_env, + 'is_ued_maze': is_ued_maze + } + + if args.screenshot_path is not None: + print('saving') + viz.screenshot(args.screenshot_path) + + viz.window.reg_key_handler(partial(key_handler, env, extras)) + viz.show(block=True) + diff --git a/src/minimax/envs/maze/__init__.py b/src/minimax/envs/maze/__init__.py new file mode 100644 index 0000000..ddc8a93 --- /dev/null +++ b/src/minimax/envs/maze/__init__.py @@ -0,0 +1,14 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .maze import Maze +from .maze_ued import UEDMaze +from .maze_ood import * + +from .maze_comparators import * +from .maze_mutators import * \ No newline at end of file diff --git a/src/minimax/envs/maze/common.py b/src/minimax/envs/maze/common.py new file mode 100644 index 0000000..3965887 --- /dev/null +++ b/src/minimax/envs/maze/common.py @@ -0,0 +1,109 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy + +import numpy as np +import jax.numpy as jnp +from flax import struct +import chex + + +OBJECT_TO_INDEX = { + "unseen": 0, + "empty": 1, + "wall": 2, + "floor": 3, + "door": 4, + "key": 5, + "ball": 6, + "box": 7, + "goal": 8, + "lava": 9, + "agent": 10, +} + + +COLORS = { + 'red' : np.array([255, 0, 0]), + 'green' : np.array([0, 255, 0]), + 'blue' : np.array([0, 0, 255]), + 'purple': np.array([112, 39, 195]), + 'yellow': np.array([255, 255, 0]), + 'grey' : np.array([100, 100, 100]) +} + + +COLOR_TO_INDEX = { + 'red' : 0, + 'green' : 1, + 'blue' : 2, + 'purple': 3, + 'yellow': 4, + 'grey' : 5, +} + + +# Map of agent direction indices to vectors +DIR_TO_VEC = jnp.array([ + # Pointing right (positive X) + (1, 0), # right + (0, 1), # down + (-1, 0), # left + (0, -1), # up +], dtype=jnp.int8) + + +@struct.dataclass +class EnvInstance: + agent_pos: chex.Array + agent_dir_idx: int + goal_pos: chex.Array + wall_map: chex.Array + + +def make_maze_map( + params, + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + pad_obs=False): + # Expand maze map to H x W x C + empty = jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8) + wall = jnp.array([OBJECT_TO_INDEX['wall'], COLOR_TO_INDEX['grey'], 0], dtype=jnp.uint8) + maze_map = jnp.array(jnp.expand_dims(wall_map, -1), dtype=jnp.uint8) + maze_map = jnp.where(maze_map > 0, wall, empty) + + agent = jnp.array([OBJECT_TO_INDEX['agent'], COLOR_TO_INDEX['red'], agent_dir_idx], dtype=jnp.uint8) + agent_x,agent_y = agent_pos + maze_map = maze_map.at[agent_y,agent_x,:].set(agent) + + goal = jnp.array([OBJECT_TO_INDEX['goal'], COLOR_TO_INDEX['green'], 0], dtype=jnp.uint8) + goal_x,goal_y = goal_pos + maze_map = maze_map.at[goal_y,goal_x,:].set(goal) + + # Add observation padding + if pad_obs: + padding = params.agent_view_size-1 + else: + padding = 1 + + maze_map_padded = jnp.tile(wall.reshape((1,1,*empty.shape)), (maze_map.shape[0]+2*padding, maze_map.shape[1]+2*padding, 1)) + maze_map_padded = maze_map_padded.at[padding:-padding,padding:-padding,:].set(maze_map) + + # Add surrounding walls + wall_start = padding-1 # start index for walls + wall_end_y = maze_map_padded.shape[0] - wall_start - 1 + wall_end_x = maze_map_padded.shape[1] - wall_start - 1 + maze_map_padded = maze_map_padded.at[wall_start,wall_start:wall_end_x+1,:].set(wall) # top + maze_map_padded = maze_map_padded.at[wall_end_y,wall_start:wall_end_x+1,:].set(wall) # bottom + maze_map_padded = maze_map_padded.at[wall_start:wall_end_y+1,wall_start,:].set(wall) # left + maze_map_padded = maze_map_padded.at[wall_start:wall_end_y+1,wall_end_x,:].set(wall) # right + + return maze_map_padded diff --git a/src/minimax/envs/maze/maze.py b/src/minimax/envs/maze/maze.py new file mode 100644 index 0000000..1c7f9c2 --- /dev/null +++ b/src/minimax/envs/maze/maze.py @@ -0,0 +1,521 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from dataclasses import dataclass +from collections import namedtuple, OrderedDict +from functools import partial +from enum import IntEnum + +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax +from typing import Tuple, Optional +import chex +from flax import struct +from flax.core.frozen_dict import FrozenDict + +from minimax.envs import environment, spaces +from minimax.envs.registration import register +import minimax.util.graph as _graph_util +from .common import ( + OBJECT_TO_INDEX, + COLORS, + COLOR_TO_INDEX, + DIR_TO_VEC, + EnvInstance, + make_maze_map) + + +class Actions(IntEnum): + # Turn left, turn right, move forward + left = 0 + right = 1 + forward = 2 + + # Pick up an object + pickup = 3 + # Drop an object + drop = 4 + # Toggle/activate an object + toggle = 5 + + # Done completing task + done = 6 + + +@struct.dataclass +class EnvState: + agent_pos: chex.Array + agent_dir: chex.Array + agent_dir_idx: int + goal_pos: chex.Array + wall_map: chex.Array + maze_map: chex.Array + time: int + terminal: bool + + +@struct.dataclass +class EnvParams: + height: int = 15 + width: int = 15 + n_walls: int = 25 + agent_view_size: int = 5 + replace_wall_pos: bool = False + see_through_walls: bool = True + see_agent: bool = False + normalize_obs: bool = False + sample_n_walls: bool = False # Sample n_walls uniformly in [0, n_walls] + obs_agent_pos: bool = False + max_episode_steps: int = 250 + singleton_seed: int = -1, + + +class Maze(environment.Environment): + def __init__( + self, + height=13, + width=13, + n_walls=25, + agent_view_size=5, + replace_wall_pos=False, + see_through_walls=True, + see_agent=False, + max_episode_steps=250, + normalize_obs=False, + sample_n_walls=False, + obs_agent_pos=False, + singleton_seed=-1 + ): + super().__init__() + + self.obs_shape = (agent_view_size, agent_view_size, 3) + self.action_set = jnp.array([ + Actions.left, + Actions.right, + Actions.forward, + Actions.pickup, + Actions.drop, + Actions.toggle, + Actions.done + ]) + + self.params = EnvParams( + height=height, + width=width, + n_walls=n_walls, + agent_view_size=agent_view_size, + replace_wall_pos=replace_wall_pos and not sample_n_walls, + see_through_walls=see_through_walls, + see_agent=see_agent, + max_episode_steps=max_episode_steps, + normalize_obs=normalize_obs, + sample_n_walls=sample_n_walls, + obs_agent_pos=obs_agent_pos, + singleton_seed=-1, + ) + + @property + def default_params(self) -> EnvParams: + # Default environment parameters + return EnvParams() + + def step_env( + self, + key: chex.PRNGKey, + state: EnvState, + action: int, + ) -> Tuple[chex.Array, EnvState, float, bool, dict]: + """Perform single timestep state transition.""" + a = self.action_set[action] + state, reward = self.step_agent(key, state, a) + # Check game condition & no. steps for termination condition + state = state.replace(time=state.time + 1) + done = self.is_terminal(state) + state = state.replace(terminal=done) + + return ( + lax.stop_gradient(self.get_obs(state)), + lax.stop_gradient(state), + reward.astype(jnp.float32), + done, + {}, + ) + + def reset_env( + self, + key: chex.PRNGKey, + ) -> Tuple[chex.Array, EnvState]: + """Reset environment state by resampling contents of maze_map + - initial agent position + - goal position + - wall positions + """ + params = self.params + h = params.height + w = params.width + all_pos = np.arange(np.prod([h, w]), dtype=jnp.uint32) + + # Reset wall map, with shape H x W, and value of 1 at (i,j) iff there is a wall at (i,j) + key, subkey = jax.random.split(key) + wall_idx = jax.random.choice( + subkey, all_pos, + shape=(params.n_walls,), + replace=params.replace_wall_pos) + + if params.sample_n_walls: + key, subkey = jax.random.split(key) + sampled_n_walls = jax.random.randint( + subkey, (), minval=0, maxval=params.n_walls) + sample_wall_mask = jnp.arange(params.n_walls) < sampled_n_walls + dummy_wall_idx = wall_idx.at[0].get().repeat(params.n_walls) + wall_idx = jax.lax.select( + sample_wall_mask, + wall_idx, + dummy_wall_idx + ) + + occupied_mask = jnp.zeros_like(all_pos) + occupied_mask = occupied_mask.at[wall_idx].set(1) + wall_map = occupied_mask.reshape(h, w).astype(jnp.bool_) + + # Reset agent position + dir + key, subkey = jax.random.split(key) + agent_idx = jax.random.choice(subkey, all_pos, shape=(1,), p=( + ~occupied_mask.astype(jnp.bool_)).astype(jnp.float32)) + occupied_mask = occupied_mask.at[agent_idx].set(1) + agent_pos = jnp.array([agent_idx % w, agent_idx//w], + dtype=jnp.uint32).flatten() + + key, subkey = jax.random.split(key) + agent_dir_idx = jax.random.choice( + subkey, jnp.arange(len(DIR_TO_VEC), dtype=jnp.uint8)) + agent_dir = DIR_TO_VEC.at[agent_dir_idx].get() + + # Reset goal position + key, subkey = jax.random.split(key) + goal_idx = jax.random.choice(subkey, all_pos, shape=(1,), p=( + ~occupied_mask.astype(jnp.bool_)).astype(jnp.float32)) + goal_pos = jnp.array([goal_idx % w, goal_idx//w], + dtype=jnp.uint32).flatten() + + maze_map = make_maze_map( + params, + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + pad_obs=True) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=agent_dir, + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=wall_map.astype(jnp.bool_), + maze_map=maze_map, + time=0, + terminal=False, + ) + + return self.get_obs(state), state + + def set_env_instance( + self, + encoding: EnvInstance): + """ + Instance is encoded as a PyTree containing the following fields: + agent_pos, agent_dir, goal_pos, wall_map + """ + params = self.params + agent_pos = encoding.agent_pos + agent_dir_idx = encoding.agent_dir_idx + + agent_dir = DIR_TO_VEC.at[agent_dir_idx].get() + goal_pos = encoding.goal_pos + wall_map = encoding.wall_map + maze_map = make_maze_map( + params, + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, # ued instances include wall padding + pad_obs=True) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=agent_dir, + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=wall_map, + maze_map=maze_map, + time=0, + terminal=False + ) + + return self.get_obs(state), state + + def get_obs(self, state: EnvState) -> chex.Array: + """Return limited grid view ahead of agent.""" + obs = jnp.zeros(self.obs_shape, dtype=jnp.uint8) + + agent_x, agent_y = state.agent_pos + + obs_fwd_bound1 = state.agent_pos + obs_fwd_bound2 = state.agent_pos + \ + state.agent_dir*(self.obs_shape[0]-1) + + side_offset = self.obs_shape[0]//2 + obs_side_bound1 = state.agent_pos + (state.agent_dir == 0)*side_offset + obs_side_bound2 = state.agent_pos - (state.agent_dir == 0)*side_offset + + all_bounds = jnp.stack( + [obs_fwd_bound1, obs_fwd_bound2, obs_side_bound1, obs_side_bound2]) + + # Clip obs to grid bounds appropriately + padding = obs.shape[0]-1 + obs_bounds_min = np.min(all_bounds, 0) + padding + obs_range_x = jnp.arange(obs.shape[0]) + obs_bounds_min[1] + obs_range_y = jnp.arange(obs.shape[0]) + obs_bounds_min[0] + + meshgrid = jnp.meshgrid(obs_range_y, obs_range_x) + coord_y = meshgrid[1].flatten() + coord_x = meshgrid[0].flatten() + + obs = state.maze_map.at[ + coord_y, coord_x, :].get().reshape(obs.shape[0], obs.shape[1], 3) + + obs = (state.agent_dir_idx == 0)*jnp.rot90(obs, 1) + \ + (state.agent_dir_idx == 1)*jnp.rot90(obs, 2) + \ + (state.agent_dir_idx == 2)*jnp.rot90(obs, 3) + \ + (state.agent_dir_idx == 3)*jnp.rot90(obs, 4) + + if not self.params.see_agent: + obs = obs.at[-1, side_offset].set( + jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8) + ) + + if not self.params.see_through_walls: + pass + + image = obs.astype(jnp.uint8) + if self.params.normalize_obs: + image = image/10.0 + + obs_dict = dict( + image=image, + agent_dir=state.agent_dir_idx + ) + if self.params.obs_agent_pos: + obs_dict.update(dict(agent_pos=state.agent_pos)) + + return OrderedDict(obs_dict) + + def step_agent(self, key: chex.PRNGKey, state: EnvState, action: int) -> Tuple[EnvState, float]: + params = self.params + + # Update agent position (forward action) + fwd_pos = jnp.minimum( + jnp.maximum(state.agent_pos + (action == + Actions.forward)*state.agent_dir, 0), + jnp.array((params.width-1, params.height-1), dtype=jnp.uint32)) + + # Can't go past wall or goal + fwd_pos_has_wall = state.wall_map.at[fwd_pos[1], fwd_pos[0]].get() + fwd_pos_has_goal = jnp.logical_and( + fwd_pos[0] == state.goal_pos[0], fwd_pos[1] == state.goal_pos[1]) + + fwd_pos_blocked = jnp.logical_or(fwd_pos_has_wall, fwd_pos_has_goal) + + agent_pos_prev = jnp.array(state.agent_pos) + agent_pos = (fwd_pos_blocked*state.agent_pos + + (~fwd_pos_blocked)*fwd_pos).astype(jnp.uint32) + + # Update agent direction (left_turn or right_turn action) + agent_dir_offset = \ + 0 \ + + (action == Actions.left)*(-1) \ + + (action == Actions.right)*1 + + agent_dir_idx = (state.agent_dir_idx + agent_dir_offset) % 4 + agent_dir = DIR_TO_VEC[agent_dir_idx] + + # Update agent component in maze_map + empty = jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8) + agent = jnp.array( + [OBJECT_TO_INDEX['agent'], COLOR_TO_INDEX['red'], agent_dir_idx], dtype=jnp.uint8) + padding = self.obs_shape[0]-1 + maze_map = state.maze_map + maze_map = maze_map.at[padding+agent_pos_prev[1], + padding+agent_pos_prev[0], :].set(empty) + maze_map = maze_map.at[padding+agent_pos[1], + padding+agent_pos[0], :].set(agent) + + # Return reward + # rng = jax.random.PRNGKey(agent_dir_idx + agent_pos[0] + agent_pos[1]) + # rand_reward = jax.random.uniform(rng) + reward = (1.0 - 0.9*((state.time+1)/params.max_episode_steps) + )*fwd_pos_has_goal # rand_reward + + return ( + state.replace( + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + agent_dir=agent_dir, + maze_map=maze_map, + terminal=fwd_pos_has_goal), + reward + ) + + def is_terminal(self, state: EnvState) -> bool: + """Check whether state is terminal.""" + done_steps = state.time >= self.params.max_episode_steps + return jnp.logical_or(done_steps, state.terminal) + + def get_eval_solved_rate_fn(self): + def _fn(ep_stats): + return ep_stats['return'] > 0 + + return _fn + + @property + def name(self) -> str: + """Environment name.""" + return "Maze" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return len(self.action_set) + + def action_space(self) -> spaces.Discrete: + """Action space of the environment.""" + return spaces.Discrete( + len(self.action_set), + dtype=jnp.uint32 + ) + + def observation_space(self) -> spaces.Dict: + """Observation space of the environment.""" + spaces_dict = { + 'image': spaces.Box(0, 255, self.obs_shape), + 'agent_dir': spaces.Discrete(4) + } + if self.params.obs_agent_pos: + params = self.params + h = params.height + w = params.width + spaces_dict.update({'agent_pos': spaces.Box( + 0, max(w, h), (2,), dtype=jnp.uint32)}) + + return spaces.Dict(spaces_dict) + + def get_monitored_metrics(self): + return () + + def state_space(self) -> spaces.Dict: + """State space of the environment.""" + params = self.params + h = params.height + w = params.width + agent_view_size = params.agent_view_size + return spaces.Dict({ + "agent_pos": spaces.Box(0, max(w, h), (2,), dtype=jnp.uint32), + "agent_dir": spaces.Discrete(4), + "goal_pos": spaces.Box(0, max(w, h), (2,), dtype=jnp.uint32), + "maze_map": spaces.Box(0, 255, (w + agent_view_size, h + agent_view_size, 3), dtype=jnp.uint32), + "time": spaces.Discrete(params.max_episode_steps), + "terminal": spaces.Discrete(2), + }) + + def max_episode_steps(self) -> int: + return self.params.max_episode_steps + + def get_env_metrics(self, state: EnvState) -> dict: + n_walls = state.wall_map.sum() + shortest_path_length = _graph_util.shortest_path_len( + state.wall_map, + state.agent_pos, + state.goal_pos + ) + + return dict( + n_walls=n_walls, + shortest_path_length=shortest_path_length, + passable=shortest_path_length > 0, + ) + + +# Register the env +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register(env_id='Maze', entry_point=module_path + ':Maze') + + +if __name__ == '__main__': + from minimax.envs.wrappers import MonitorReturnWrapper + + render = False + n_envs = 16384 + + if render: + from minimax.envs.viz.grid_viz import GridVisualizer + viz = GridVisualizer() + obs_viz = GridVisualizer() + + viz.show() + obs_viz.show() + + kwargs = dict( + max_episode_steps=250, + height=15, + width=15, + n_walls=25, + agent_view_size=5, + see_through_walls=True + ) + env = MonitorReturnWrapper(Maze(**kwargs)) + params = env.params + extra = env.reset_extra() + + jit_reset_env = jax.jit(env.reset) + jit_step_env = jax.jit(env.step) + + key = jax.random.PRNGKey(0) + vrngs = jax.random.split(key, n_envs) + key, subkey = jax.random.split(jax.random.PRNGKey(0)) + obs, state, extra = jax.vmap( + jit_reset_env, in_axes=(0))(vrngs) + + all_sps = [] + import time + for i in range(1000): + print('step', i) + key, subkey = jax.random.split(key) + vrngs = jax.random.split(subkey, n_envs) + start = time.time() + obs, state, reward, done, info, extra = jax.vmap(jit_step_env)( + vrngs, state, action=jax.vmap(env.action_space().sample, in_axes=(0))(vrngs), extra=extra) + obs['image'].block_until_ready() + end = time.time() + # print(f"sps: {1/(end-start) * n_envs}") + # print('return:', info['return']) + + all_sps.append(1/(end-start) * n_envs) + + if render: + viz.render(params, state) + obs_viz.render_grid(np.asarray(env.get_obs( + state)['image']), k_rot90=0, agent_dir_idx=3) + + print('mean sps:', np.mean(all_sps)) + print('std sps:', np.std(all_sps)) diff --git a/src/minimax/envs/maze/maze_comparators.py b/src/minimax/envs/maze/maze_comparators.py new file mode 100644 index 0000000..5a2aa64 --- /dev/null +++ b/src/minimax/envs/maze/maze_comparators.py @@ -0,0 +1,34 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import numpy as np +import jax +import jax.numpy as jnp + +from minimax.envs.registration import register_comparator + + +@jax.jit +def is_equal_map(a, b): + agent_pos_eq = jnp.equal(a.agent_pos, b.agent_pos).all() + goal_pos_eq = jnp.equal(a.goal_pos, b.goal_pos).all() + wall_map_eq = jnp.equal(a.wall_map, b.wall_map).all() + + _eq = jnp.logical_and(agent_pos_eq, goal_pos_eq) + _eq = jnp.logical_and(_eq, wall_map_eq) + + return _eq + + +# Register the mutators +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register_comparator(env_id='Maze', comparator_id=None, entry_point=module_path + ':is_equal_map') \ No newline at end of file diff --git a/src/minimax/envs/maze/maze_mutators.py b/src/minimax/envs/maze/maze_mutators.py new file mode 100644 index 0000000..47907e4 --- /dev/null +++ b/src/minimax/envs/maze/maze_mutators.py @@ -0,0 +1,110 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from enum import IntEnum + +import numpy as np +import jax +import jax.numpy as jnp + +from .common import make_maze_map +from minimax.envs.registration import register_mutator + + +class Mutations(IntEnum): + # Turn left, turn right, move forward + NO_OP = 0 + FLIP_WALL = 1 + MOVE_GOAL = 2 + + +def flip_wall(rng, state): + wall_map = state.wall_map + h,w = wall_map.shape + wall_mask = jnp.ones((h*w,), dtype=jnp.bool_) + + goal_idx = w*state.goal_pos[1] + state.goal_pos[0] + agent_idx = w*state.agent_pos[1] + state.agent_pos[0] + wall_mask = wall_mask.at[goal_idx].set(False) + wall_mask = wall_mask.at[agent_idx].set(False) + + flip_idx = jax.random.choice(rng, np.arange(h*w), p=wall_mask) + flip_y = flip_idx//w + flip_x = flip_idx%w + + flip_val = ~wall_map.at[flip_y,flip_x].get() + next_wall_map = wall_map.at[flip_y,flip_x].set(flip_val) + + return state.replace(wall_map=next_wall_map) + + +def move_goal(rng, state): + wall_map = state.wall_map + h,w = wall_map.shape + wall_mask = wall_map.flatten() + + goal_idx = w*state.goal_pos[1] + state.goal_pos[0] + agent_idx = w*state.agent_pos[1] + state.agent_pos[0] + wall_mask = wall_mask.at[goal_idx].set(True) + wall_mask = wall_mask.at[agent_idx].set(True) + + next_goal_idx = jax.random.choice(rng, np.arange(h*w), p=~wall_mask) + next_goal_y = next_goal_idx//w + next_goal_x = next_goal_idx%w + + next_wall_map = wall_map.at[next_goal_y,next_goal_x].set(False) + next_goal_pos = jnp.array([next_goal_x,next_goal_y], dtype=jnp.uint32) + + return state.replace(wall_map=next_wall_map, goal_pos=next_goal_pos) + + +@partial(jax.jit, static_argnums=(1,3)) +def move_goal_flip_walls(rng, params, state, n=1): + if n == 0: + return state + + def _mutate(carry, step): + state = carry + rng, mutation = step + + rng, arng, brng = jax.random.split(rng,3) + + is_flip_wall = jnp.equal(mutation, Mutations.FLIP_WALL.value) + mutated_state = flip_wall(arng, state) + next_state = jax.tree_map(lambda x,y: jax.lax.select(is_flip_wall, x, y), mutated_state, state) + + is_move_goal = jnp.equal(mutation, Mutations.MOVE_GOAL.value) + mutated_state = move_goal(brng, state) + next_state = jax.tree_map(lambda x,y: jax.lax.select(is_move_goal, x, y), mutated_state, next_state) + + return next_state, None + + rng, nrng, *mrngs = jax.random.split(rng, n+2) + mutations = jax.random.choice(nrng, np.arange(len(Mutations)), (n,)) + + state, _ = jax.lax.scan(_mutate, state, (jnp.array(mrngs), mutations)) + + # Update state maze_map + next_maze_map = make_maze_map( + params, + state.wall_map, + state.goal_pos, + state.agent_pos, + state.agent_dir_idx, + pad_obs=True) + + return state.replace(maze_map=next_maze_map) + +# Register the mutators +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register_mutator(env_id='Maze', mutator_id=None, entry_point=module_path + ':move_goal_flip_walls') \ No newline at end of file diff --git a/src/minimax/envs/maze/maze_ood.py b/src/minimax/envs/maze/maze_ood.py new file mode 100644 index 0000000..fa9a125 --- /dev/null +++ b/src/minimax/envs/maze/maze_ood.py @@ -0,0 +1,1111 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Tuple, Optional + +import jax +import jax.numpy as jnp +from flax import struct +import chex + +from minimax.envs.registration import register +from .common import ( + DIR_TO_VEC, + OBJECT_TO_INDEX, + COLOR_TO_INDEX, + make_maze_map, +) +from .maze import ( + Maze, + EnvParams, + EnvState, + Actions +) + +# ======== Singleton mazes ======== +class MazeSingleton(Maze): + def __init__( + self, + height=15, + width=15, + wall_map=None, + goal_pos=None, + agent_pos=None, + agent_dir_idx=None, + agent_view_size=5, + see_through_walls=True, + see_agent=False, + normalize_obs=False, + obs_agent_pos=False, + max_episode_steps=None, + singleton_seed=-1, + ): + super().__init__( + height=height, + width=width, + agent_view_size=agent_view_size, + see_through_walls=see_through_walls, + see_agent=see_agent, + normalize_obs=normalize_obs, + obs_agent_pos=obs_agent_pos, + max_episode_steps=max_episode_steps, + singleton_seed=singleton_seed + ) + + if wall_map is None: + self.wall_map = jnp.zeros((height,width), dtype=jnp.bool_) + else: + self.wall_map = \ + jnp.array( + [[int(x) for x in row.split()] + for row in wall_map], dtype=jnp.bool_) + height, width = self.wall_map.shape + + if max_episode_steps is None: + max_episode_steps = 2*(height+2)*(width+2) # Match original eval steps + + self.goal_pos_choices = None + if goal_pos is None: + self.goal_pos = jnp.array([height, width]) - jnp.ones(2, dtype=jnp.uint32) + elif isinstance(goal_pos, (tuple, list)) \ + and isinstance(goal_pos[0], (tuple, list)): + self.goal_pos_choices = jnp.array(goal_pos, dtype=jnp.uint32) + self.goal_pos = goal_pos[0] + else: + self.goal_pos = jnp.array(goal_pos, dtype=jnp.uint32) + + if agent_pos is None: + self.agent_pos = jnp.zeros(2, dtype=jnp.uint32) + else: + self.agent_pos = jnp.array(agent_pos, dtype=jnp.uint32) + + self.agent_dir_idx = agent_dir_idx + + if self.agent_dir_idx is None: + self.agent_dir_idx = 0 + + self.params = EnvParams( + height=height, + width=width, + agent_view_size=agent_view_size, + see_through_walls=see_through_walls, + see_agent=see_agent, + normalize_obs=normalize_obs, + obs_agent_pos=obs_agent_pos, + max_episode_steps=max_episode_steps, + singleton_seed=-1, + ) + + self.maze_map = make_maze_map( + self.params, + self.wall_map, + self.goal_pos, + self.agent_pos, + self.agent_dir_idx, + pad_obs=True) + + @property + def default_params(self) -> EnvParams: + # Default environment parameters + return EnvParams() + + def reset_env( + self, + key: chex.PRNGKey, + ) -> Tuple[chex.Array, EnvState]: + + if self.agent_dir_idx is None: + key, subkey = jax.random.split(key) + agent_dir_idx = jax.random.choice(subkey, 4) + else: + agent_dir_idx = self.agent_dir_idx + + if self.goal_pos_choices is not None: + key, subkey = jax.random.split(key) + goal_pos = jax.random.choice(subkey, self.goal_pos_choices) + maze_map = make_maze_map( + self.params, + self.wall_map, + goal_pos, + self.agent_pos, + agent_dir_idx, + pad_obs=True) + else: + goal_pos = self.goal_pos + maze_map = self.maze_map + + state = EnvState( + agent_pos=self.agent_pos, + agent_dir=DIR_TO_VEC[agent_dir_idx], + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=self.wall_map, + maze_map=maze_map, + time=0, + terminal=False, + ) + + return self.get_obs(state), state + + +# ======== Specific mazes ======== +class SixteenRooms(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 1 0 0 1 0 0 1 0 0 0", + "0 0 0 0 0 0 0 0 0 1 0 0 0", + "0 0 0 1 0 0 1 0 0 0 0 0 0", + "1 0 1 1 1 0 1 1 0 1 1 1 0", + "0 0 0 1 0 0 0 0 0 0 0 0 0", + "0 0 0 0 0 0 1 0 0 1 0 0 0", + "1 1 0 1 0 1 1 0 1 1 1 0 1", + "0 0 0 1 0 0 0 0 0 1 0 0 0", + "0 0 0 1 0 0 1 0 0 0 0 0 0", + "0 1 1 1 1 0 1 1 0 1 0 1 1", + "0 0 0 1 0 0 1 0 0 1 0 0 0", + "0 0 0 0 0 0 1 0 0 0 0 0 0", + "0 0 0 1 0 0 0 0 0 1 0 0 0" + ] + goal_pos = (11,11) + agent_pos = (1,1) + agent_dir_idx = 0 + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class SixteenRooms2(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 1 0 0 0 0 0 1 0 0 0", + "0 0 0 0 0 0 1 0 0 1 0 0 0", + "0 0 0 1 0 0 1 0 0 1 0 0 0", + "1 1 1 1 0 1 1 0 1 1 1 0 1", + "0 0 0 1 0 0 1 0 0 0 0 0 0", + "0 0 0 0 0 0 1 0 0 1 0 0 0", + "1 0 1 1 1 1 1 0 1 1 1 1 1", + "0 0 0 1 0 0 1 0 0 1 0 0 0", + "0 0 0 1 0 0 0 0 0 0 0 0 0", + "1 1 0 1 1 0 1 1 0 1 1 1 1", + "0 0 0 1 0 0 1 0 0 1 0 0 0", + "0 0 0 0 0 0 1 0 0 0 0 0 0", + "0 0 0 1 0 0 1 0 0 1 0 0 0" + ] + goal_pos = (11,11) + agent_pos = (1,1) + agent_dir_idx = None + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class Labyrinth(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 0 0 0 0 0 0 0 0 0 0", + "0 1 1 1 1 1 1 1 1 1 1 1 0", + "0 1 0 0 0 0 0 0 0 0 0 1 0", + "0 1 0 1 1 1 1 1 1 1 0 1 0", + "0 1 0 1 0 0 0 0 0 1 0 1 0", + "0 1 0 1 0 1 1 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 0 0 1 0 0 0 1 0 1 0", + "0 1 1 1 1 1 1 1 1 1 0 1 0", + "0 0 0 0 0 1 0 0 0 0 0 1 0", + "1 1 1 1 0 1 0 1 1 1 1 1 0", + "0 0 0 0 0 1 0 0 0 0 0 0 0" + ] + goal_pos = (6,6) + agent_pos = (0,12) + agent_dir_idx = 0 + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class LabyrinthFlipped(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + '0 0 0 0 0 0 0 0 0 0 0 0 0', + '0 1 1 1 1 1 1 1 1 1 1 1 0', + '0 1 0 0 0 0 0 0 0 0 0 1 0', + '0 1 0 1 1 1 1 1 1 1 0 1 0', + '0 1 0 1 0 0 0 0 0 1 0 1 0', + '0 1 0 1 0 1 1 1 0 1 0 1 0', + '0 1 0 1 0 1 0 1 0 1 0 1 0', + '0 1 0 1 0 1 0 1 0 1 0 1 0', + '0 1 0 1 0 0 0 1 0 0 0 1 0', + '0 1 0 1 1 1 1 1 1 1 1 1 0', + '0 1 0 0 0 0 0 1 0 0 0 0 0', + '0 1 1 1 1 1 0 1 0 1 1 1 1', + '0 0 0 0 0 0 0 1 0 0 0 0 0' + ] + goal_pos = (6,6) + agent_pos = (12,12) + agent_dir_idx = 2 + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class Labyrinth2(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 1 0 0 0 0 0 0 0 0 0 0 0", + "0 1 0 1 1 1 1 1 1 1 1 1 0", + "0 1 0 1 0 0 0 0 0 0 0 1 0", + "0 1 0 1 0 1 1 1 1 1 0 1 0", + "0 1 0 1 0 1 0 0 0 1 0 1 0", + "0 0 0 1 0 1 0 1 0 1 0 1 0", + "1 1 1 1 0 1 0 1 0 1 0 1 0", + "0 0 0 1 0 1 1 1 0 1 0 1 0", + "0 1 0 1 0 0 0 0 0 1 0 1 0", + "0 1 0 1 1 1 1 1 1 1 0 1 0", + "0 1 0 0 0 0 0 0 0 0 0 1 0", + "0 1 1 1 1 1 1 1 1 1 1 1 0", + "0 0 0 0 0 0 0 0 0 0 0 0 0" + ] + goal_pos = (6,6) + agent_pos = (0,0) + agent_dir_idx = None + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class StandardMaze(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 0 0 1 0 0 0 0 1 0 0", + "0 1 1 1 0 1 1 1 1 0 1 1 0", + "0 1 0 0 0 0 0 0 0 0 0 0 0", + "0 1 1 1 1 1 1 1 1 0 1 1 1", + "0 0 0 0 0 0 0 0 1 0 0 0 0", + "1 1 1 1 1 1 0 1 1 1 1 1 0", + "0 0 0 0 1 0 0 1 0 0 0 0 0", + "0 1 1 0 0 0 1 1 0 1 1 1 1", + "0 0 1 0 1 0 0 1 0 0 0 1 0", + "1 0 1 0 1 1 0 1 1 1 0 1 0", + "1 0 1 0 0 1 0 0 0 1 0 0 0", + "1 0 1 1 0 1 1 1 0 1 1 1 0", + "0 0 0 1 0 0 0 1 0 1 0 0 0" + ] + goal_pos = (6,12) + agent_pos = (6,0) + agent_dir_idx = 0 + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class StandardMaze2(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 1 0 1 0 0 0 0 1 0 0", + "0 1 0 1 0 1 1 1 1 0 0 0 1", + "0 1 0 0 0 0 0 0 0 0 1 0 0", + "0 1 1 1 1 1 1 1 1 0 1 1 1", + "0 0 0 1 0 0 1 0 1 0 1 0 0", + "1 1 0 1 0 1 1 0 1 0 1 0 0", + "0 1 0 1 0 0 0 0 1 0 1 1 0", + "0 1 0 1 1 0 1 1 1 0 0 1 0", + "0 1 0 0 1 0 0 1 1 1 0 1 0", + "0 1 1 0 1 1 0 1 0 1 0 1 0", + "0 1 0 0 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 0 0 1 0 0 0 1 0 0 0 0 0" + ] + goal_pos = (12,4) + agent_pos = (0,6) + agent_dir_idx = None + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class StandardMaze3(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 0 1 0 1 0 0 0 0 0 0", + "0 1 1 1 1 0 1 0 1 1 1 1 0", + "0 1 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 1 1 1 1 0 1 0 1 0 1", + "1 1 0 1 0 0 0 0 1 0 1 0 0", + "0 0 0 1 0 1 1 0 1 0 1 1 0", + "0 1 0 1 0 1 0 0 1 0 0 1 0", + "0 1 0 1 0 1 0 1 1 1 0 1 1", + "0 1 0 0 0 1 0 1 0 1 0 0 0", + "0 1 1 1 0 1 0 1 0 1 1 1 0", + "0 1 0 0 0 1 0 1 0 0 0 1 0", + "0 1 0 1 1 1 0 1 0 1 0 1 0", + "0 1 0 0 0 1 0 0 0 1 0 0 0" + ] + goal_pos = (12,6) + agent_pos = (3,0) + agent_dir_idx = None + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class SmallCorridor(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 0 0 0 0 0 0 0 0 0 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 1 1 1 1 1 1 1 1 1 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 0 0 0 0 0 0 0 0 0 0 0 0" + ] + goal_pos = [ + (2,5),(4,5),(6,5),(8,5),(10,5), + (2,7),(4,7),(6,7),(8,7),(10,7), + ] + agent_pos = (0,6) + agent_dir_idx = None + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class LargeCorridor(MazeSingleton): + def __init__( + self, + see_agent=False, + normalize_obs=False): + wall_map = [ + "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0 1 0", + "0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0", + ] + goal_pos = [ + (2,8),(4,8),(6,8),(8,8),(10,8),(12,8),(14,8),(16,8), + (2,10),(4,10),(6,10),(8,10),(10,10),(12,10),(14,10),(16,10) + ] + agent_pos = (0,9) + agent_dir_idx = None + + super().__init__( + wall_map=wall_map, + goal_pos=goal_pos, + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + see_agent=see_agent, + normalize_obs=normalize_obs + ) + + +class FourRooms(Maze): + def __init__( + self, + height=17, + width=17, + agent_view_size=5, + see_through_walls=True, + see_agent=False, + normalize_obs=False, + max_episode_steps=250, + singleton_seed=-1): + + super().__init__( + height=height, + width=width, + agent_view_size=agent_view_size, + see_through_walls=see_through_walls, + see_agent=see_agent, + normalize_obs=normalize_obs, + max_episode_steps=max_episode_steps, + singleton_seed=singleton_seed + ) + + assert height % 2 == 1 and width % 2 == 1, \ + 'Grid height and width must be odd' + + wall_map = jnp.zeros((height, width), dtype=jnp.bool_) + wall_map = wall_map.at[height//2, :].set(True) + wall_map = wall_map.at[:, width//2].set(True) + self.wall_map = wall_map + + self.room_h = height//2 + self.room_w = width//2 + + self.all_pos_idxs = jnp.arange(height*width) + self.goal_pos_mask = (~wall_map).flatten() + self.agent_pos_mask = self.goal_pos_mask + + def reset_env( + self, + key: chex.PRNGKey + ) -> Tuple[chex.Array, EnvState]: + # Randomize door positions + params = self.params + + key, x_rng, y_rng = jax.random.split(key,3) + x_door_idxs = jax.random.randint(x_rng, (2,), 0, self.room_w) \ + + jnp.array([0, self.room_w+1], dtype=jnp.uint32) + + y_door_idxs = jax.random.randint(y_rng, (2,), 0, self.room_h) \ + + jnp.array([0, self.room_h+1], dtype=jnp.uint32) + + wall_map = self.wall_map.at[self.room_h, x_door_idxs].set(False) + wall_map = wall_map.at[y_door_idxs,self.room_w].set(False) + + # Randomize goal pos + key, subkey = jax.random.split(key) + goal_pos_idx = jax.random.choice(subkey, self.all_pos_idxs, shape=(), p=self.goal_pos_mask) + goal_pos = jnp.array([goal_pos_idx%params.width, goal_pos_idx//params.width], dtype=jnp.uint32) + + # Randomize agent pos + key, subkey = jax.random.split(key) + agent_pos_mask = self.agent_pos_mask.at[goal_pos_idx].set(False) + agent_pos_idx = jax.random.choice(subkey, self.all_pos_idxs, shape=(), p=self.agent_pos_mask) + agent_pos = jnp.array([agent_pos_idx%params.width, agent_pos_idx//params.width], dtype=jnp.uint32) + + key, subkey = jax.random.split(key) + agent_dir_idx = jax.random.choice(subkey, 4) + + maze_map = make_maze_map( + self.params, + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + pad_obs=True) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=DIR_TO_VEC[agent_dir_idx], + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=wall_map, + maze_map=maze_map, + time=0, + terminal=False, + ) + + return self.get_obs(state), state + + +class Crossing(Maze): + def __init__( + self, + height=9, + width=9, + n_crossings=5, + agent_view_size=5, + see_through_walls=True, + see_agent=False, + normalize_obs=False, + max_episode_steps=250, + singleton_seed=-1): + self.n_crossings = n_crossings + max_episode_steps = 4*(height+2)*(width+2) + + super().__init__( + height=height, + width=width, + agent_view_size=agent_view_size, + see_through_walls=see_through_walls, + see_agent=see_agent, + normalize_obs=normalize_obs, + max_episode_steps=max_episode_steps, + singleton_seed=singleton_seed + ) + + def reset_env( + self, + key: chex.PRNGKey + ) -> Tuple[chex.Array, EnvState]: + params = self.params + height, width = params.height, params.width + goal_pos = jnp.array([width-1, height-1]) + agent_pos = jnp.array([0,0], dtype=jnp.uint32) + agent_dir_idx = 0 + + # Generate walls + wall_map = jnp.zeros((height, width), dtype=jnp.bool_) + + row_y_choices = jnp.arange(1,height-1,2) + col_x_choices = jnp.arange(1,width-1,2) + + rng, subrng = jax.random.split(key) + dirs = jax.random.permutation( + subrng, + jnp.concatenate( + (jnp.zeros(len(row_y_choices)), + jnp.ones(len(col_x_choices))) + ) + )[:self.n_crossings] + + n_v = sum(dirs.astype(jnp.uint32)) + n_h = len(dirs) - n_v + + rng, row_rng, col_rng = jax.random.split(rng, 3) + + row_ys_mask = jax.random.permutation(row_rng, (jnp.arange(len(row_y_choices)) < n_v).repeat(2)) + if height % 2 == 0: + row_ys_mask = jnp.concatenate((row_ys_mask, jnp.zeros(2))) + else: + row_ys_mask = jnp.concatenate((row_ys_mask, jnp.zeros(1))) + + row_ys_mask = jnp.logical_and( + jnp.zeros(height, dtype=jnp.bool_).at[row_y_choices].set(True), + row_ys_mask + ) + + col_xs_mask = jax.random.permutation(col_rng, (jnp.arange(len(col_x_choices)) < n_h).repeat(2)) + if width % 2 == 0: + col_xs_mask = jnp.concatenate((col_xs_mask, jnp.zeros(2))) + else: + col_xs_mask = jnp.concatenate((col_xs_mask, jnp.zeros(1))) + + col_xs_mask = jnp.logical_and( + jnp.zeros(width, dtype=jnp.bool_).at[col_x_choices].set(True), + col_xs_mask + ) + + wall_map = jnp.logical_or( + wall_map, + jnp.tile(jnp.expand_dims(row_ys_mask,-1), (1,width)) + ) + + wall_map = jnp.logical_or( + wall_map, + jnp.tile(jnp.expand_dims(col_xs_mask,0), (height,1)) + ) + + # Generate wall openings + def _scan_step(carry, rng): + wall_map, pos, passed_wall, last_dir, last_dir_idx = carry + + dir_idx = jax.random.randint(rng,(),0,2) + + go_dir = (~passed_wall)*DIR_TO_VEC[dir_idx] + passed_wall*last_dir + next_pos = pos + go_dir + + # If next pos is the right border, force direction to be down + collide = jnp.logical_or( + (next_pos[0] >= width), + (next_pos[1] >= height) + ) + go_dir = collide*DIR_TO_VEC[(dir_idx+1)%2] + (~collide)*go_dir + dir_idx = (dir_idx+1)%2 + (~collide)*dir_idx + + next_pos = collide*(pos + go_dir) + (~collide)*next_pos + + last_dir = go_dir + last_dir_idx = dir_idx + pos = next_pos + + passed_wall = wall_map[pos[1],pos[0]] + wall_map = wall_map.at[pos[1], pos[0]].set(False) + + return (wall_map, pos.astype(jnp.uint32), passed_wall, last_dir, last_dir_idx), None + + n_steps_to_goal = width + height - 2 + rng, *subrngs = jax.random.split(rng, n_steps_to_goal+1) + + pos = agent_pos + passed_wall = jnp.array(False) + last_dir = DIR_TO_VEC[0] + + (wall_map, pos, passed_wall, last_dir, last_dir_idx), _ = jax.lax.scan( + _scan_step, + (wall_map, pos, passed_wall, last_dir, 0), + jnp.array(subrngs), + length=n_steps_to_goal + ) + + maze_map = make_maze_map( + self.params, + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + pad_obs=True) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=DIR_TO_VEC[agent_dir_idx], + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=wall_map, + maze_map=maze_map, + time=0, + terminal=False, + ) + + return self.get_obs(state), state + + +NEIGHBOR_WALL_OFFSETS = jnp.array([ + [1,0], # right + [0,1], # bottom + [-1,0], # left + [0,-1], # top + [0,0] # self +], dtype=jnp.int32) + + +class PerfectMaze(Maze): + def __init__( + self, + height=13, + width=13, + agent_view_size=5, + see_through_walls=True, + see_agent=False, + normalize_obs=False, + max_episode_steps=250, + singleton_seed=-1): + + assert height % 2 == 1 and width % 2 == 1, \ + 'Maze dimensions must be odd.' + + max_episode_steps = 2*(width+2)*(height+2) + super().__init__( + height=height, + width=width, + agent_view_size=agent_view_size, + see_through_walls=see_through_walls, + see_agent=see_agent, + normalize_obs=normalize_obs, + max_episode_steps=max_episode_steps, + singleton_seed=singleton_seed + ) + + def reset_env( + self, + key: chex.PRNGKey + ) -> Tuple[chex.Array, EnvState]: + """ + Generate a perfect maze using an iterated search procedure. + """ + params = self.params + height, width = self.params.height, self.params.width + n_tiles = height*width + + # Track maze wall map + wall_map = jnp.ones((height, width), dtype=jnp.bool_) + + # Track visited, walkable tiles + _h = height//2+1 + _w = width//2+1 + visited_map = jnp.zeros((_h, _w), dtype=jnp.bool_) + vstack = jnp.zeros((_h*_w, 2), dtype=jnp.uint32) + vstack_size = 0 + + # Get initial start tile in walkable index + key, subkey = jax.random.split(key) + start_pos_x = jax.random.randint(subkey, (), 0, _w) + start_pos_y = jax.random.randint(subkey, (), 0, _h) + start_pos = jnp.array([start_pos_x,start_pos_y], dtype=jnp.uint32) + + # Set initial start tile as visited + visited_map = visited_map.at[ + start_pos[1],start_pos[0] + ].set(True) + wall_map = wall_map.at[ + 2*start_pos[1],2*start_pos[0] + ].set(False) + vstack = vstack.at[vstack_size:vstack_size+2].set(start_pos) + vstack_size += 2 + + def _scan_step(carry, key): + # Choose last visited tile and move to a neighbor + wall_map, visited_map, vstack, vstack_size = carry + + abs_pos = 2*vstack[vstack_size-1] + + neighbor_wall_offsets = NEIGHBOR_WALL_OFFSETS.at[-1].set( + vstack[vstack_size-2] - vstack[vstack_size-1] + ) + + # Find a random unvisited neighbor + neighbor_pos = \ + jnp.minimum( + jnp.maximum( + jnp.tile(abs_pos, (len(NEIGHBOR_WALL_OFFSETS),1)) \ + + 2*neighbor_wall_offsets, 0 + ), + jnp.array([width, height], dtype=jnp.uint32) + ) + + # Check for unvisited neighbors. Set self to unvisited if all visited. + neighbor_visited = visited_map.at[ + neighbor_pos[:,1]//2, neighbor_pos[:,0]//2 + ].get() + + n_neighbor_visited = neighbor_visited[:4].sum() + all_visited = n_neighbor_visited == 4 + all_visited_post = n_neighbor_visited >= 3 + neighbor_visited = neighbor_visited.at[-1].set(~all_visited) + + # Choose a random unvisited neigbor and remove walls between current tile + # and this neighbor and at this neighbor. + rand_neighbor_idx = jax.random.choice( + key, jnp.arange(len(NEIGHBOR_WALL_OFFSETS)), p=~neighbor_visited) + rand_neighbor_pos = neighbor_pos[rand_neighbor_idx] + rand_neighbor_wall_pos = abs_pos + (~all_visited)*neighbor_wall_offsets[rand_neighbor_idx] + remove_wall_pos = jnp.concatenate( + (jnp.expand_dims(rand_neighbor_pos, 0), + jnp.expand_dims(rand_neighbor_wall_pos,0)), 0) + wall_map = wall_map.at[ + remove_wall_pos[:,1], remove_wall_pos[:,0] + ].set(False) + + # Set selected neighbor as visited + visited_map = visited_map.at[ + rand_neighbor_pos[1]//2,rand_neighbor_pos[0]//2 + ].set(True) + + # Pop current tile from stack if all neighbors have been visited + vstack_size -= all_visited_post + + # Push selected neighbor onto stack + vstack = vstack.at[vstack_size].set( + rand_neighbor_pos//2 + ) + vstack_size += ~all_visited + + return (wall_map, visited_map, vstack, vstack_size), None + + # for i in range(3*_h*_w): + max_n_steps = 2*_w*_h + key, *subkeys = jax.random.split(key, max_n_steps+1) + (wall_map, visited_map, vstack, vstack_size), _ = jax.lax.scan( + _scan_step, + (wall_map, visited_map, vstack, vstack_size), + jnp.array(subkeys), + length=max_n_steps + ) + + # Randomize goal position + all_pos_idx = jnp.arange(height*width) + + key, subkey = jax.random.split(key) + goal_mask = ~wall_map.flatten() + goal_pos_idx = jax.random.choice(subkey, all_pos_idx, p=goal_mask) + goal_pos = jnp.array([goal_pos_idx%width, goal_pos_idx//width]) + + # Randomize agent position + key, subkey = jax.random.split(key) + agent_mask = goal_mask.at[goal_pos_idx].set(False) + agent_pos_idx = jax.random.choice(subkey, all_pos_idx, p=agent_mask) + agent_pos = jnp.array([agent_pos_idx%width, agent_pos_idx//width], dtype=jnp.uint32) + + # Randomize agent dir + key, subkey = jax.random.split(key) + agent_dir_idx = jax.random.choice(subkey, 4) + + maze_map = make_maze_map( + self.params, + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + pad_obs=True) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=DIR_TO_VEC[agent_dir_idx], + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=wall_map, + maze_map=maze_map, + time=0, + terminal=False, + ) + + return self.get_obs(state), state + + +class PerfectMazeMedium(PerfectMaze): + def __init__(self, *args, **kwargs): + super().__init__(height=19, width=19, *args, **kwargs) + + +class PerfectMazeExtraLarge(PerfectMaze): + def __init__(self, *args, **kwargs): + super().__init__(height=101, width=101, *args, **kwargs) + + +class Memory(MazeSingleton): + def __init__( + self, + height=17, + width=17, + agent_view_size=7, + see_through_walls=True, + see_agent=False, + normalize_obs=False, + obs_agent_pos=False, + max_episode_steps=250, + singleton_seed=-1): + + # Generate walls + wall_map = [ + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "1 1 1 1 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 1 1 1 1 1 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 0 0 1 0 0 0 0", + "0 0 0 1 1 1 1 1 1 0 1 0 0 0 0", + "1 1 1 1 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0", + "0 0 0 0 0 0 0 0 1 0 1 0 0 0 0" + ] + + super().__init__( + wall_map=wall_map, + goal_pos=(9,5), + agent_pos=(0,7), + agent_dir_idx=0, + see_agent=see_agent, + normalize_obs=normalize_obs, + obs_agent_pos=obs_agent_pos, + max_episode_steps=max_episode_steps + ) + + self.top_pos = jnp.array([9,5], dtype=jnp.uint32) + self.bottom_pos = jnp.array([9,9], dtype=jnp.uint32) + + def reset_env( + self, + key: chex.PRNGKey, + ) -> Tuple[chex.Array, EnvState]: + params = self.params + height, width = params.height, params.width + + agent_pos = jnp.array([0,7], dtype=jnp.uint32) + agent_dir_idx = 0 + + # Randomly generate a memory location + is_top_goal = jax.random.randint(key, minval=0, maxval=2, shape=(1,), dtype=jnp.uint8) + + clue_pos = jnp.array((0,6), dtype=jnp.uint32) + self.goal_pos = is_top_goal*self.top_pos + (1-is_top_goal)*self.bottom_pos + self.distractor_pos = is_top_goal*self.bottom_pos + (1-is_top_goal)*self.top_pos + + goal_color = is_top_goal*COLOR_TO_INDEX['red'] + (1-is_top_goal)*COLOR_TO_INDEX['green'] + + wall_map = self.wall_map + maze_map = make_maze_map( + self.params, + jnp.array(wall_map, dtype=jnp.bool_), + self.goal_pos, + agent_pos, + agent_dir_idx, + pad_obs=True) + + red_goal = jnp.array([OBJECT_TO_INDEX['goal'], COLOR_TO_INDEX['red'], 0], dtype=jnp.uint8) + green_goal = jnp.array([OBJECT_TO_INDEX['goal'], COLOR_TO_INDEX['green'], 0], dtype=jnp.uint8) + clue = is_top_goal*red_goal + (1-is_top_goal)*green_goal + + padding = params.agent_view_size-1 + wall_map = wall_map.at[clue_pos[1], clue_pos[0]].set(True) + maze_map = maze_map.at[padding+clue_pos[1], padding+clue_pos[0]].set(clue) + + wall_map = wall_map.at[self.top_pos[1], self.top_pos[0]].set(True) + maze_map = maze_map.at[padding+self.top_pos[1], padding+self.top_pos[0]].set(red_goal) + + wall_map = wall_map.at[self.bottom_pos[1], self.bottom_pos[0]].set(True) + maze_map = maze_map.at[padding+self.bottom_pos[1], padding+self.bottom_pos[0]].set(green_goal) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=DIR_TO_VEC[agent_dir_idx], + agent_dir_idx=agent_dir_idx, + goal_pos=self.goal_pos, + wall_map=wall_map, + maze_map=maze_map, + time=0, + terminal=False, + ) + + return self.get_obs(state), state + + def get_distractor_pos(self, state): + goal_x, goal_y = state.goal_pos + is_top_goal = jnp.logical_and( + goal_x == self.top_pos[0], + goal_y == self.top_pos[1] + ) + + return is_top_goal*self.bottom_pos + (1-is_top_goal)*self.top_pos + + + def step_agent(self, key: chex.PRNGKey, state: EnvState, action: int) -> Tuple[EnvState, float]: + next_state, reward = super().step_agent( + key=key, + state=state, + action=action + ) + + fwd_pos = jnp.minimum( + jnp.maximum(state.agent_pos + (action == Actions.forward)*state.agent_dir, 0), + jnp.array(( + self.params.width-1, self.params.height-1), + dtype=jnp.uint32 + )) + + distractor_pos = self.get_distractor_pos(state) + fwd_pos_has_distractor = jnp.logical_and( + fwd_pos[0] == distractor_pos[0], + fwd_pos[1] == distractor_pos[1] + ) + + next_state = next_state.replace( + terminal=jnp.logical_or( + next_state.terminal, + fwd_pos_has_distractor + ) + ) + + return ( + next_state, + reward + ) + + +# ======== Registration ======== +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register(env_id='Maze-SixteenRooms', entry_point=module_path + ':SixteenRooms') +register(env_id='Maze-SixteenRooms2', entry_point=module_path + ':SixteenRooms2') +register(env_id='Maze-Labyrinth', entry_point=module_path + ':Labyrinth') +register(env_id='Maze-Labyrinth2', entry_point=module_path + ':Labyrinth2') +register(env_id='Maze-LabyrinthFlipped', entry_point=module_path + ':LabyrinthFlipped') +register(env_id='Maze-StandardMaze', entry_point=module_path + ':StandardMaze') +register(env_id='Maze-StandardMaze2', entry_point=module_path + ':StandardMaze2') +register(env_id='Maze-StandardMaze3', entry_point=module_path + ':StandardMaze3') +register(env_id='Maze-SmallCorridor', entry_point=module_path + ':SmallCorridor') +register(env_id='Maze-LargeCorridor', entry_point=module_path + ':LargeCorridor') + +register(env_id='Maze-FourRooms', entry_point=module_path + ':FourRooms') +register(env_id='Maze-Crossing', entry_point=module_path + ':Crossing') +register(env_id='Maze-PerfectMaze', entry_point=module_path + ':PerfectMaze') +register(env_id='Maze-PerfectMazeMedium', entry_point=module_path + ':PerfectMazeMedium') +register(env_id='Maze-PerfectMazeXL', entry_point=module_path + ':PerfectMazeExtraLarge') + +register(env_id='Maze-Memory', entry_point=module_path + ':Memory') \ No newline at end of file diff --git a/src/minimax/envs/maze/maze_ued.py b/src/minimax/envs/maze/maze_ued.py new file mode 100644 index 0000000..1d80e86 --- /dev/null +++ b/src/minimax/envs/maze/maze_ued.py @@ -0,0 +1,425 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from dataclasses import dataclass +from collections import namedtuple, OrderedDict +from functools import partial +from enum import IntEnum + +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax +from typing import Tuple, Optional +import chex +from flax import struct +from flax.core.frozen_dict import FrozenDict + +from .common import EnvInstance, make_maze_map +from minimax.envs import environment, spaces +from minimax.envs.registration import register_ued + + +class SequentialActions(IntEnum): + skip = 0 + wall = 1 + goal = 2 + agent = 3 + + +@struct.dataclass +class EnvState: + encoding: chex.Array + time: int + terminal: bool + + +@struct.dataclass +class EnvParams: + height: int = 15 + width: int = 15 + n_walls: int = 25 + noise_dim: int = 50 + replace_wall_pos: bool = False + fixed_n_wall_steps: bool = False + first_wall_pos_sets_budget: bool = False + use_seq_actions: bool = False + set_agent_dir: bool = False + normalize_obs: bool = False + singleton_seed: int = -1 + + +class UEDMaze(environment.Environment): + def __init__( + self, + height=13, + width=13, + n_walls=25, + noise_dim=16, + replace_wall_pos=False, + fixed_n_wall_steps=False, + first_wall_pos_sets_budget=False, + use_seq_actions=False, + set_agent_dir=False, + normalize_obs=False, + ): + """ + Using the original action space requires ensuring proper handling + of a sequence with trailing dones, e.g. dones: 0 0 0 0 1 1 1 1 1 ... 1. + Advantages and value losses should only be computed where ~dones[0]. + """ + assert not (first_wall_pos_sets_budget and fixed_n_wall_steps), \ + 'Setting first_wall_pos_sets_budget=True requires fixed_n_wall_steps=False.' + + super().__init__() + + self.n_tiles = height*width + # go straight, turn left, turn right, take action + self.action_set = jnp.array(jnp.arange(self.n_tiles)) + + self.params = EnvParams( + height=height, + width=width, + n_walls=n_walls, + noise_dim=noise_dim, + replace_wall_pos=replace_wall_pos, + fixed_n_wall_steps=fixed_n_wall_steps, + first_wall_pos_sets_budget=first_wall_pos_sets_budget, + use_seq_actions=False, + set_agent_dir=set_agent_dir, + normalize_obs=normalize_obs, + ) + + @staticmethod + def align_kwargs(kwargs, other_kwargs): + kwargs.update(dict( + height=other_kwargs['height'], + width=other_kwargs['width'], + )) + + return kwargs + + def _add_noise_to_obs(self, rng, obs): + if self.params.noise_dim > 0: + noise = jax.random.uniform(rng, (self.params.noise_dim,)) + obs.update(dict(noise=noise)) + + return obs + + def reset_env( + self, + key: chex.PRNGKey): + """ + Prepares the environment state for a new design + from a blank slate. + """ + params = self.params + noise_rng, dir_rng = jax.random.split(key) + encoding = jnp.zeros((self._get_encoding_dim(),), dtype=jnp.uint32) + + if not params.set_agent_dir: + rand_dir = jax.random.randint( + dir_rng, (), minval=0, maxval=4) # deterministic + tile_scale_dir = jnp.ceil( + (rand_dir/4)*self.n_tiles).astype(jnp.uint32) + encoding = encoding.at[-1].set(tile_scale_dir) + + state = EnvState( + encoding=encoding, + time=0, + terminal=False, + ) + + obs = self._add_noise_to_obs( + noise_rng, + self.get_obs(state) + ) + + return obs, state + + def step_env( + self, + key: chex.PRNGKey, + state: EnvState, + action: int, + ) -> Tuple[chex.Array, EnvState, float, bool, dict]: + """ + Take a design step. + action: A pos as an int from 0 to (height*width)-1 + """ + params = self.params + + collision_rng, noise_rng = jax.random.split(key) + + # Sample a random free tile in case of a collision + dist_values = jnp.logical_and( # True if position taken + jnp.ones(params.n_walls + 2), + jnp.arange(params.n_walls + 2)+1 > state.time + ) + + # Get zero-indexed last wall time step + if params.fixed_n_wall_steps: + max_n_walls = params.n_walls + encoding_pos = state.encoding[:params.n_walls+2] + last_wall_step_idx = max_n_walls - 1 + else: + max_n_walls = jnp.round( + params.n_walls*state.encoding[0]/self.n_tiles).astype(jnp.uint32) + + if self.params.first_wall_pos_sets_budget: + encoding_pos = state.encoding[:params.n_walls+2] + last_wall_step_idx = jnp.maximum(max_n_walls, 1) - 1 + else: + encoding_pos = state.encoding[1:params.n_walls+3] + last_wall_step_idx = max_n_walls + + pos_dist = jnp.ones(self.n_tiles).at[ + jnp.flip(encoding_pos)].set(jnp.flip(dist_values)) + all_pos = jnp.arange(self.n_tiles, dtype=jnp.uint32) + + # Only mark collision if replace_wall_pos=False OR the agent is placed over the goal + goal_step_idx = last_wall_step_idx + 1 + agent_step_idx = last_wall_step_idx + 2 + + # Track whether it is the last time step + next_state = state.replace(time=state.time + 1) + done = self.is_terminal(next_state) + + # Always place agent idx in last enc position. + is_agent_dir_step = jnp.logical_and( + params.set_agent_dir, + done + ) + + collision = jnp.logical_and( + pos_dist[action] < 1, + jnp.logical_or( + not params.replace_wall_pos, + jnp.logical_and( # agent pos cannot override goal + jnp.equal(state.time, agent_step_idx), + jnp.equal(state.encoding[goal_step_idx], action) + ) + ) + ) + collision = (collision * (1-is_agent_dir_step)).astype(jnp.uint32) + + action = (1-collision)*action + \ + collision*jax.random.choice(collision_rng, + all_pos, replace=False, p=pos_dist) + + enc_idx = (1-is_agent_dir_step)*state.time + is_agent_dir_step*(-1) + encoding = state.encoding.at[enc_idx].set(action) + + next_state = next_state.replace( + encoding=encoding, + terminal=done + ) + reward = 0 + + obs = self._add_noise_to_obs(noise_rng, self.get_obs(next_state)) + + return ( + lax.stop_gradient(obs), + lax.stop_gradient(next_state), + reward, + done, + {}, + ) + + def get_env_instance( + self, + key: chex.PRNGKey, + state: EnvState + ) -> chex.Array: + """ + Converts internal encoding to an instance encoding that + can be interpreted by the `set_to_instance` method + the paired Environment class. + """ + params = self.params + h = params.height + w = params.width + enc = state.encoding + + # === Extract agent_dir, agent_pos, and goal_pos === + # Num walls placed currently + if params.fixed_n_wall_steps: + n_walls = params.n_walls + enc_len = self._get_encoding_dim() + wall_pos_idx = jnp.flip(enc[:params.n_walls]) + agent_pos_idx = enc_len-2 # Enc is full length + goal_pos_idx = enc_len-3 + else: + n_walls = jnp.round( + params.n_walls*enc[0]/self.n_tiles + ).astype(jnp.uint32) + if params.first_wall_pos_sets_budget: + # So 0-padding does not override pos=0 + wall_pos_idx = jnp.flip(enc[:params.n_walls]) + enc_len = n_walls + 2 # [wall_pos] + len((goal, agent)) + else: + wall_pos_idx = jnp.flip(enc[1:params.n_walls+1]) + # [wall_pos] + len((n_walls, goal, agent)) + enc_len = n_walls + 3 + # Positions are relative to n_walls when n_walls is variable. + agent_pos_idx = enc_len-1 + goal_pos_idx = enc_len-2 + + # Get agent + goal info (set agent/goal pos 1-step out of range if they are not yet placed) + goal_placed = state.time > jnp.array([goal_pos_idx], dtype=jnp.uint32) + goal_pos = \ + goal_placed*jnp.array([enc[goal_pos_idx] % w, enc[goal_pos_idx]//w], dtype=jnp.uint32) \ + + (~goal_placed)*jnp.array([w, h], dtype=jnp.uint32) + + agent_placed = state.time > jnp.array( + [agent_pos_idx], dtype=jnp.uint32) + agent_pos = \ + agent_placed*jnp.array([enc[agent_pos_idx] % w, enc[agent_pos_idx]//w], dtype=jnp.uint32) \ + + (~agent_placed)*jnp.array([w, h], dtype=jnp.uint32) + + agent_dir_idx = jnp.floor((4*enc[-1]/self.n_tiles)).astype(jnp.uint8) + + # Make wall map + wall_start_time = jnp.logical_and( # 1 if explicitly predict # blocks, else 0 + not params.fixed_n_wall_steps, + not params.first_wall_pos_sets_budget + ).astype(jnp.uint32) + wall_map = jnp.zeros(h*w, dtype=jnp.bool_) + wall_values = jnp.arange( + params.n_walls) + wall_start_time < jnp.minimum(state.time, n_walls + wall_start_time) + wall_values = jnp.flip(wall_values) + wall_map = wall_map.at[wall_pos_idx].set(wall_values) + + # Zero out walls where agent and goal reside + agent_mask = agent_placed * \ + (~(jnp.arange(h*w) == + state.encoding[agent_pos_idx])) + ~agent_placed*wall_map + goal_mask = goal_placed * \ + (~(jnp.arange(h*w) == + state.encoding[goal_pos_idx])) + ~goal_placed*wall_map + wall_map = wall_map*agent_mask*goal_mask + wall_map = wall_map.reshape(h, w) + + return EnvInstance( + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=wall_map + ) + + def is_terminal(self, state: EnvState) -> bool: + done_steps = state.time >= self.max_episode_steps() + return jnp.logical_or(done_steps, state.terminal) + + def _get_post_terminal_obs(self, state: EnvState): + dtype = jnp.float32 if self.params.normalize_obs else jnp.uint8 + image = jnp.zeros(( + self.params.height+2, self.params.width+2, 3), dtype=dtype + ) + + return OrderedDict(dict( + image=image, + time=state.time, + noise=jnp.zeros(self.params.noise_dim, dtype=jnp.float32), + )) + + def get_obs(self, state: EnvState): + instance = self.get_env_instance(jax.random.PRNGKey(0), state) + + image = make_maze_map( + self.params, + instance.wall_map, + instance.goal_pos, + instance.agent_pos, + instance.agent_dir_idx, + pad_obs=False + ) + + if self.params.normalize_obs: + image = image/10.0 + + return OrderedDict(dict( + image=image, + time=state.time, + )) + + @property + def default_params(self): + return EnvParams() + + @property + def name(self) -> str: + """Environment name.""" + return "UEDMaze" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return len(self.action_set) + + def action_space(self) -> spaces.Discrete: + """Action space of the environment.""" + params = self.params + return spaces.Discrete( + params.height*params.width, + dtype=jnp.uint32 + ) + + def observation_space(self) -> spaces.Dict: + """Observation space of the environment.""" + params = self.params + max_episode_steps = self.max_episode_steps() + spaces_dict = { + 'image': spaces.Box(0, 255, (params.height+2, params.width+2, 3)), + 'time': spaces.Discrete(max_episode_steps), + } + if self.params.noise_dim > 0: + spaces_dict.update({ + 'noise': spaces.Box(0, 1, (self.params.noise_dim,)) + }) + return spaces.Dict(spaces_dict) + + def state_space(self) -> spaces.Dict: + """State space of the environment.""" + params = self.params + encoding_dim = self._get_encoding_dim() + max_episode_steps = self.max_episode_steps() + h = params.height + w = params.width + return spaces.Dict({ + 'encoding': spaces.Box(0, 255, (encoding_dim,)), + 'time': spaces.Discrete(max_episode_steps), + "terminal": spaces.Discrete(2), + }) + + def _get_encoding_dim(self) -> int: + encoding_dim = self.max_episode_steps() + if not self.params.set_agent_dir: + encoding_dim += 1 # max steps is 1 less than full encoding dim + + return encoding_dim + + def max_episode_steps(self) -> int: + if self.params.fixed_n_wall_steps \ + or self.params.first_wall_pos_sets_budget: + max_episode_steps = self.params.n_walls + 2 + else: + max_episode_steps = self.params.n_walls + 3 + + if self.params.set_agent_dir: + max_episode_steps += 1 + + return max_episode_steps + + +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register_ued(env_id='Maze', entry_point=module_path + ':UEDMaze') diff --git a/src/minimax/envs/overcooked_proc/__init__.py b/src/minimax/envs/overcooked_proc/__init__.py new file mode 100644 index 0000000..43e30c5 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/__init__.py @@ -0,0 +1,14 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .overcooked import Overcooked +from .overcooked_ued import UEDOvercooked +from .overcooked_ood import * + +from .overcooked_comparators import * +from .overcooked_mutators import * \ No newline at end of file diff --git a/src/minimax/envs/overcooked_proc/common.py b/src/minimax/envs/overcooked_proc/common.py new file mode 100644 index 0000000..973dcb0 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/common.py @@ -0,0 +1,207 @@ +# Edited from JaxMarl: https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/overcooked + +from flax import struct +import chex +import numpy as np +import jax.numpy as jnp +import jax + + +OBJECT_TO_INDEX = { + "unseen": 0, + "empty": 1, + "wall": 2, + "onion": 3, + "onion_pile": 4, + "plate": 5, + "plate_pile": 6, + "goal": 7, + "pot": 8, + "dish": 9, + "agent": 10, +} + + +COLORS = { + 'red': np.array([255, 0, 0]), + 'green': np.array([0, 255, 0]), + 'blue': np.array([0, 0, 255]), + 'purple': np.array([112, 39, 195]), + 'yellow': np.array([255, 255, 0]), + 'grey': np.array([100, 100, 100]), + 'white': np.array([255, 255, 255]), + 'black': np.array([25, 25, 25]), + 'orange': np.array([230, 180, 0]), +} + + +COLOR_TO_INDEX = { + 'red': 0, + 'green': 1, + 'blue': 2, + 'purple': 3, + 'yellow': 4, + 'grey': 5, + 'white': 6, + 'black': 7, + 'orange': 8, +} + + +OBJECT_INDEX_TO_VEC = jnp.array([ + jnp.array([OBJECT_TO_INDEX['unseen'], 0, 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['wall'], COLOR_TO_INDEX['grey'], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['onion'], + COLOR_TO_INDEX["yellow"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['onion_pile'], + COLOR_TO_INDEX["yellow"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['plate'], + COLOR_TO_INDEX["white"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['plate_pile'], + COLOR_TO_INDEX["white"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['goal'], COLOR_TO_INDEX['green'], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['pot'], COLOR_TO_INDEX['black'], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['dish'], COLOR_TO_INDEX["white"], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['agent'], COLOR_TO_INDEX['red'], 0], + dtype=jnp.uint8), # Default color and direction +]) + + +# Map of agent direction indices to vectors +DIR_TO_VEC = jnp.array([ + # Pointing right (positive X) + # (1, 0), # right + # (0, 1), # down + # (-1, 0), # left + # (0, -1), # up + (0, -1), # NORTH + (0, 1), # SOUTH + (1, 0), # EAST + (-1, 0), # WEST +], dtype=jnp.int8) + + +@struct.dataclass +class EnvInstance: + agent_pos: chex.Array + agent_dir_idx: chex.Array + agent_inv: chex.Array + goal_pos: chex.Array + pot_pos: chex.Array + onion_pile_pos: chex.Array + plate_pile_pos: chex.Array + wall_map: chex.Array + + +def make_overcooked_map( + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + plate_pile_pos, + onion_pile_pos, + pot_pos, + pot_status, + onion_pos, + plate_pos, + dish_pos, + pad_obs=True, + num_agents=2, + agent_view_size=5): + # Expand maze map to H x W x C + empty = jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8) + wall = jnp.array( + [OBJECT_TO_INDEX['wall'], COLOR_TO_INDEX['grey'], 0], dtype=jnp.uint8) + maze_map = jnp.array(jnp.expand_dims(wall_map, -1), dtype=jnp.uint8) + maze_map = jnp.where(maze_map > 0, wall, empty) + + # Add agents + def _get_agent_updates(agent_dir_idx, agent_pos, agent_idx): + agent = jnp.array([OBJECT_TO_INDEX['agent'], COLOR_TO_INDEX['red'] + + agent_idx*2, agent_dir_idx], dtype=jnp.uint8) + agent_x, agent_y = agent_pos + return agent_x, agent_y, agent + + agent_x_vec, agent_y_vec, agent_vec = jax.vmap(_get_agent_updates, in_axes=( + 0, 0, 0))(agent_dir_idx, agent_pos, jnp.arange(num_agents)) + maze_map = maze_map.at[agent_y_vec, agent_x_vec, :].set(agent_vec) + + # Add goals + goal = jnp.array( + [OBJECT_TO_INDEX['goal'], COLOR_TO_INDEX['green'], 0], dtype=jnp.uint8) + + def set_based_on_position_mask(maze_map, pos_mask, obj): + pos_expanded = jnp.repeat( + jnp.expand_dims(pos_mask, axis=-1), 3, axis=-1) + obj_maze_map = pos_expanded * jnp.tile(obj, (*pos_mask.shape, 1)) + maze_map = maze_map * \ + jnp.logical_not(pos_expanded) + obj_maze_map * pos_expanded + return maze_map + + maze_map = set_based_on_position_mask(maze_map, goal_pos, goal) + + # Add onions + onion_pile = jnp.array( + [OBJECT_TO_INDEX['onion_pile'], COLOR_TO_INDEX["yellow"], 0], dtype=jnp.uint8) + maze_map = set_based_on_position_mask(maze_map, onion_pile_pos, onion_pile) + + # Add plates + plate_pile = jnp.array( + [OBJECT_TO_INDEX['plate_pile'], COLOR_TO_INDEX["white"], 0], dtype=jnp.uint8) + maze_map = set_based_on_position_mask(maze_map, plate_pile_pos, plate_pile) + + pot_obj = jnp.array( + [OBJECT_TO_INDEX['pot'], COLOR_TO_INDEX["black"], 0], dtype=jnp.uint8) + pot_status = pot_status.reshape(pot_pos.shape) + pot_status = jnp.concatenate((jnp.zeros( + (*pot_status.shape, 2), dtype=jnp.uint8), pot_status[:, :, jnp.newaxis]), axis=-1) + pos_expanded = jnp.repeat(jnp.expand_dims(pot_pos, axis=-1), 3, axis=-1) + obj_maze_map = pos_expanded * \ + jnp.tile(pot_obj, (*pot_pos.shape, 1)) + pot_status + maze_map = maze_map * \ + jnp.logical_not(pos_expanded) + obj_maze_map * pos_expanded + + onion = jnp.array( + [OBJECT_TO_INDEX['onion'], COLOR_TO_INDEX["yellow"], 0], dtype=jnp.uint8) + maze_map = set_based_on_position_mask(maze_map, onion_pos, onion) + + plate = jnp.array( + [OBJECT_TO_INDEX['plate'], COLOR_TO_INDEX["white"], 0], dtype=jnp.uint8) + maze_map = set_based_on_position_mask(maze_map, plate_pos, plate) + + dish = jnp.array( + [OBJECT_TO_INDEX['dish'], COLOR_TO_INDEX["white"], 0], dtype=jnp.uint8) + maze_map = set_based_on_position_mask(maze_map, dish_pos, dish) + + # Add observation padding + if pad_obs: + padding = agent_view_size-1 + else: + padding = 1 + + maze_map_padded = jnp.tile(wall.reshape( + (1, 1, *empty.shape)), (maze_map.shape[0]+2*padding, maze_map.shape[1]+2*padding, 1)) + maze_map_padded = maze_map_padded.at[ + padding:-padding, padding:-padding, :].set(maze_map) + + # Add surrounding walls + wall_start = padding-1 # start index for walls + wall_end_y = maze_map_padded.shape[0] - wall_start - 1 + wall_end_x = maze_map_padded.shape[1] - wall_start - 1 + maze_map_padded = maze_map_padded.at[wall_start, + wall_start:wall_end_x+1, :].set(wall) # top + maze_map_padded = maze_map_padded.at[wall_end_y, + wall_start:wall_end_x+1, :].set(wall) # bottom + # left + maze_map_padded = maze_map_padded.at[wall_start:wall_end_y + + 1, wall_start, :].set(wall) + # right + maze_map_padded = maze_map_padded.at[wall_start:wall_end_y + + 1, wall_end_x, :].set(wall) + + return maze_map_padded diff --git a/src/minimax/envs/overcooked_proc/interactive.py b/src/minimax/envs/overcooked_proc/interactive.py new file mode 100644 index 0000000..1094f90 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/interactive.py @@ -0,0 +1,290 @@ +# Edited from JaxMarl: https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/overcooked + +import argparse +from functools import partial + +import jax +import jax.numpy as jnp + + +from jaxmarl.environments.overcooked.overcooked import Overcooked +from jaxmarl.environments.overcooked.layouts import overcooked_layouts as layouts + + +def redraw(state, obs, extras): + extras['viz'].render(extras['agent_view_size'], state, highlight=False) + + +def reset(key, env, extras): + key, subkey = jax.random.split(extras['rng']) + obs, state = extras['jit_reset'](subkey) + + extras['rng'] = key + extras['obs'] = obs + extras['state'] = state + + redraw(state, obs, extras) + + +def step(env, action, extras): + key, subkey = jax.random.split(extras['rng']) + + actions = {"agent_0": jnp.array(action), "agent_1": jnp.array(action)} + print("Actions : ", actions) + obs, state, reward, done, info = jax.jit( + env.step_env)(subkey, extras['state'], actions) + extras['obs'] = obs + extras['state'] = state + print( + f"t={state.time}: reward={reward['agent_0']}, agent_dir={state.agent_dir_idx}, agent_inv={state.agent_inv}, done = {done['__all__']}") + + if extras["debug"]: + layers = [f"player_{i}_loc" for i in range(2)] + layers.extend( + [f"player_{i // 4}_orientation_{i % 4}" for i in range(8)]) + layers.extend([ + "pot_loc", + "counter_loc", + "onion_disp_loc", + "tomato_disp_loc", + "plate_disp_loc", + "serve_loc", + "onions_in_pot", + "tomatoes_in_pot", + "onions_in_soup", + "tomatoes_in_soup", + "soup_cook_time_remaining", + "soup_done", + "plates", + "onions", + "tomatoes", + "urgency" + ]) + print("obs_shape: ", obs["agent_0"].shape) + print("OBS: \n", obs["agent_0"]) + debug_obs = jnp.transpose(obs["agent_0"], (2, 0, 1)) + for i, layer in enumerate(layers): + print(layer) + print(debug_obs[i]) + # print(f"agent obs =\n {obs}") + + if done["__all__"] or (jnp.array([action, action]) == Actions.done).any(): + key, subkey = jax.random.split(subkey) + reset(subkey, env, extras) + else: + redraw(state, obs, extras) + + extras['rng'] = key + + +def key_handler(env, extras, event): + print('pressed', event.key) + + if event.key == 'escape': + window.close() + return + + if event.key == 'backspace': + extras['jit_reset']((env, extras)) + return + + if event.key == 'left': + step(env, Actions.left, extras) + return + if event.key == 'right': + step(env, Actions.right, extras) + return + if event.key == 'up': + step(env, Actions.forward, extras) + return + + # Spacebar + if event.key == ' ': + step(env, Actions.toggle, extras) + return + if event.key == '[': + step(env, Actions.pickup, extras) + return + if event.key == ']': + step(env, Actions.drop, extras) + return + + if event.key == 'enter': + step(env, Actions.done, extras) + return + + +def key_handler_overcooked(env, extras, event): + print('pressed', event.key) + + if event.key == 'escape': + window.close() + return + if event.key == 'backspace': + extras['jit_reset']((env, extras)) + return + + if event.key == 'left': + step(env, Actions.left, extras) + return + if event.key == 'right': + step(env, Actions.right, extras) + return + if event.key == 'up': + # step(env, Actions.forward, extras) + step(env, Actions.up, extras) + return + if event.key == 'down': + step(env, Actions.down, extras) + return + + # Spacebar + if event.key == ' ': + step(env, Actions.interact, extras) + return + if event.key == 'tab': + step(env, Actions.stay, extras) + return + if event.key == 'enter': + step(env, Actions.done, extras) + return + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--env", + type=str, + help="Environment name", + default="Overcooked" + ) + parser.add_argument( + "--layout", + type=str, + help="Overcooked layout", + default="cramped_room" + ) + parser.add_argument( + '--random_reset', + default=False, + help="Reset to random state", + action='store_true' + ) + parser.add_argument( + "--seed", + type=int, + help="random seed to generate the environment with", + default=0 + ) + parser.add_argument( + '--render_agent_view', + default=False, + help="draw the agent sees (partially observable view)", + action='store_true' + ) + # parser.add_argument( + # '--height', + # default=13, + # type=int, + # help="height", + # ) + # parser.add_argument( + # '--width', + # default=13, + # type=int, + # help="width", + # ) + # parser.add_argument( + # '--n_walls', + # default=50, + # type=int, + # help="Number of walls", + # ) + # parser.add_argument( + # '--agent_view_size', + # default=5, + # type=int, + # help="Number of walls", + # ) + parser.add_argument( + '--debug', + default=False, + help="Debug mode", + action='store_true' + ) + args = parser.parse_args() + + # if args.env == "Maze": + # env = Maze( + # height=13, + # width=13, + # n_walls=25, + # see_agent=True, + # ) + # from jaxmarl.gridworld.grid_viz import GridVisualizer as Visualizer + # from jaxmarl.gridworld.maze import Actions + # + # params = env.params + # + # elif args.env == "MAMaze": + # env = MAMaze( + # height=13, + # width=13, + # n_walls=25, + # see_agent=True, + # n_agents=2 + # ) + # from jaxmarl.gridworld.grid_viz import GridVisualizer as Visualizer + # from jaxmarl.gridworld.maze import Actions + # + # params = env.params + + if args.env == "Overcooked": + if len(args.layout) > 0: + layout = layouts[args.layout] + env = Overcooked( + layout=layout, + random_reset=args.random_reset + ) + else: + print("You must provide a layout.") + from jaxmarl.viz.overcooked_visualizer import OvercookedVisualizer as Visualizer + from jaxmarl.environments.overcooked.overcooked import Actions + + viz = Visualizer() + obs_viz = None + obs_viz2 = None + if args.render_agent_view: + obs_viz = Visualizer() + if args.env == "MAMaze" or "Overcooked": + obs_viz2 = Visualizer() + + with jax.disable_jit(False): + jit_reset = jax.jit(env.reset) + # jit_reset = env.reset_env + key = jax.random.PRNGKey(args.seed) + key, subkey = jax.random.split(key) + o0, s0 = jit_reset(subkey) + viz.render(env.agent_view_size, s0, highlight=False) + + key, subkey = jax.random.split(key) + extras = { + 'rng': subkey, + 'state': s0, + 'obs': o0, + 'viz': viz, + 'obs_viz': obs_viz, + 'obs_viz2': obs_viz2, + 'jit_reset': jit_reset, + 'agent_view_size': env.agent_view_size, + 'env': args.env, + 'debug': args.debug + } + + if args.env == "Overcooked": + viz.window.reg_key_handler( + partial(key_handler_overcooked, env, extras)) + viz.show(block=True) + else: + viz.window.reg_key_handler(partial(key_handler, env, extras)) + viz.show(block=True) diff --git a/src/minimax/envs/overcooked_proc/layouts.py b/src/minimax/envs/overcooked_proc/layouts.py new file mode 100644 index 0000000..78ef5b0 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/layouts.py @@ -0,0 +1,556 @@ +# Edited from JaxMarl: https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/overcooked + + +import json +import jax.numpy as jnp +from flax.core.frozen_dict import FrozenDict + +cramped_room = { + "height": 4, + "width": 5, + "wall_idx": jnp.array([0, 1, 2, 3, 4, + 5, 9, + 10, 14, + 15, 16, 17, 18, 19]), + "agent_idx": jnp.array([6, 8]), + "goal_idx": jnp.array([18]), + "plate_pile_idx": jnp.array([16]), + "onion_pile_idx": jnp.array([5, 9]), + "pot_idx": jnp.array([2]) +} +asymm_advantages = { + "height": 5, + "width": 9, + "wall_idx": jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 11, 12, 13, 14, 15, 17, + 18, 22, 26, + 27, 31, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44]), + "agent_idx": jnp.array([29, 32]), + "goal_idx": jnp.array([12, 17]), + "plate_pile_idx": jnp.array([39, 41]), + "onion_pile_idx": jnp.array([9, 14]), + "pot_idx": jnp.array([22, 31]) +} +coord_ring = { + "height": 5, + "width": 5, + "wall_idx": jnp.array([0, 1, 2, 3, 4, + 5, 9, + 10, 12, 14, + 15, 19, + 20, 21, 22, 23, 24]), + "agent_idx": jnp.array([7, 11]), + "goal_idx": jnp.array([22]), + "plate_pile_idx": jnp.array([10]), + "onion_pile_idx": jnp.array([15, 21]), + "pot_idx": jnp.array([3, 9]) +} +forced_coord = { + "height": 5, + "width": 5, + "wall_idx": jnp.array([0, 1, 2, 3, 4, + 5, 7, 9, + 10, 12, 14, + 15, 17, 19, + 20, 21, 22, 23, 24]), + "agent_idx": jnp.array([11, 8]), + "goal_idx": jnp.array([23]), + "onion_pile_idx": jnp.array([5, 10]), + "plate_pile_idx": jnp.array([15]), + "pot_idx": jnp.array([3, 9]) +} + +# Example of layout provided as a grid +counter_circuit_grid = """ +WWWPPWWW +W A W +B WWWW X +W AW +WWWOOWWW +""" + +asymm_advantages_6_9 = """ +WWWWWWWWW +O WXWOW X +W P A W +WA P W +WWWBWBWWW +WWWWWWWWW +""" + +counter_circuit_6_9 = """ +WWWPPWWWW +W A WW +B WWWW XW +W AWW +WWWOOWWWW +WWWWWWWWW +""" + +forced_coord_6_9 = """ +WWWPWWWWW +OAWAPWWWW +O W WWWWW +B W WWWWW +WWWXWWWWW +WWWWWWWWW +""" + +cramped_room_6_9 = """ +WWPWWWWWW +OAA OWWWW +W WWWWW +WBWXWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +coord_ring_6_9 = """ +WWWPWWWWW +WA APWWWW +B W WWWWW +O WWWWW +WOXWWWWWW +WWWWWWWWW +""" + +quad_6_9 = """ +WWWWWWWWW +WWPA WWW +WWB AWWW +WWWOXOWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_1 = """ +WWWPWWWWW +WWBA WWW +WWO AWWW +WWWXOWWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_2 = """ +WWWBPWWWW +WWOA WWW +WWX AWWW +WWWOWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_3 = """ +WWWOBPWWW +WWXA WWW +WWO AWWW +WWWWWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_4 = """ +WWWXOBWWW +WWOA PWW +WWW AWWW +WWWWWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_5 = """ +WWWOXOWWW +WWWA BWW +WWW APWW +WWWWWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_6 = """ +WWWWOXWWW +WWWA OWW +WWW ABWW +WWWWWPWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_7 = """ +WWWWWOWWW +WWWA XWW +WWW AOWW +WWWWPBWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_8 = """ +WWWWWWWWW +WWWA OWW +WWW AXWW +WWWPBOWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9_9 = """ +WWWWWWWWW +WWWA WWW +WWP AOWW +WWWBOXWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9T = """ +WWWOXOWWW +WWW APWW +WWWA BWW +WWWWWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +quad_6_9M = """ +WWWOXOWWW +WWB APWW +WWPA BWW +WWWOXOWWW +WWWWWWWWW +WWWWWWWWW +""" + +asymm_advantages_10_15 = """ +WWWWWWWWWWWWWWW +O WXWOW XWWWWWW +W P A WWWWWWW +WA P WWWWWWW +WWWBWBWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +""" + +asymm_advantages_B_10_15 = """ +WWWWWWWWWWWWWWW +O WXWOW XWWWWWW +W P A WWWWWWW +W W WWWWWWW +WA P WWWWWWW +WWWBWBWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +""" + +asymm_advantages_M_10_15 = """ +WWWWWWWWWWWWWWW +O WXWOW XWWWW +W P A WWWWW +WA W WWWWW +W P WWWWW +WWWBWWWBWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +""" + +asymm_advantages_L_10_15 = """ +WWWWWWWWWWWWWWW +O WXWOW XWW +W P A WWW +WA W WWW +W W WWW +W P WWW +WWWBWWWWWBWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +WWWWWWWWWWWWWWW +""" + +asymm_advantages_m = """ +WWWWWWWWW +O WOWXW X +W P A W +WA P W +WWWBWBWWW +""" + +asymm_advantages_m1 = """ +WWWWWWWWW +X WXWOW O +W P A W +WA P W +WWWBWBWWW +""" + +asymm_advantages_m2 = """ +WWWWWWWWW +O WXWOW X +WA P A W +W P W +WWWBWBWWW +""" + +asymm_advantages_m3 = """ +WWWWWWWWW +O WXWOW X +W P AW +WA P W +WWWBWBWWW +""" + +asymm_advantages_m4 = """ +WWWWWWWWW +O WXWOW X +W P A W +WA B W +WWWPWBWWW +""" + +asymm_advantages_m5 = """ +WWWWWWWWW +O WXWOW X +W B A W +WA P W +WWWPWBWWW +""" + +asymm_advantages_m6 = """ +WWWWWWWWW +O WXWOW X +W P A W +WA P W +WWBWWBWWW +""" + +asymm_advantages_m7 = """ +WWWWWWWWW +O WXWOW X +W P A W +WA P W +WWBWWWBWW +""" + +asymm_advantages_m8 = """ +WWWWWWWWW +O WXWOW X +W P A W +WA P W +WBWWWWBWW +""" + + +def layout_grid_to_onehot_dict(grid): + """Assumes `grid` is string representation of the layout, with 1 line per row, and the following symbols: + W: wall + A: agent + X: goal + B: plate (bowl) pile + O: onion pile + P: pot location + ' ' (space) : empty cell + """ + + rows = grid.split('\n') + + if len(rows[0]) == 0: + rows = rows[1:] + if len(rows[-1]) == 0: + rows = rows[:-1] + + keys = ["wall_idx", "agent_idx", "goal_idx", + "plate_pile_idx", "onion_pile_idx", "pot_idx", "empty_table_idx"] + symbol_to_key = {"W": "wall_idx", + "A": "agent_idx", + "X": "goal_idx", + "B": "plate_pile_idx", + "O": "onion_pile_idx", + "P": "pot_idx"} + + layout_dict = {key: [] for key in keys} + layout_dict["height"] = len(rows) + layout_dict["width"] = len(rows[0]) + width = len(rows[0]) + + for i, row in enumerate(rows): + for j, obj in enumerate(row): + idx = width * i + j + # if obj in symbol_to_key.keys(): + # # Add object + # layout_dict[symbol_to_key[obj]].append(idx) + + if obj == "A": + # Agent + layout_dict["agent_idx"].append(idx) + + if obj == "X": + # Goal + layout_dict["goal_idx"].append(1) + else: + layout_dict["goal_idx"].append(0) + + if obj == "B": + # Plate pile + layout_dict["plate_pile_idx"].append(1) + else: + layout_dict["plate_pile_idx"].append(0) + + if obj == "O": + # Onion pile + layout_dict["onion_pile_idx"].append(1) + else: + layout_dict["onion_pile_idx"].append(0) + + if obj == "P": + # Pot location + layout_dict["pot_idx"].append(1) + else: + layout_dict["pot_idx"].append(0) + + if obj in ["X", "B", "O", "P", "W"]: + # These objects are also walls technically + layout_dict["wall_idx"].append(1) + else: + layout_dict["wall_idx"].append(0) + + if obj == "W": + # Goal + layout_dict["empty_table_idx"].append(1) + else: + layout_dict["empty_table_idx"].append(0) + # elif obj == " ": + # # Empty cell + # continue + + for key in layout_dict.keys(): + # Transform lists to arrays + layout_dict[key] = jnp.array(layout_dict[key], dtype=jnp.uint8) + + return FrozenDict(layout_dict) + + +def layout_grid_to_dict(grid): + """Assumes `grid` is string representation of the layout, with 1 line per row, and the following symbols: + W: wall + A: agent + X: goal + B: plate (bowl) pile + O: onion pile + P: pot location + ' ' (space) : empty cell + """ + + rows = grid.split('\n') + + if len(rows[0]) == 0: + rows = rows[1:] + if len(rows[-1]) == 0: + rows = rows[:-1] + + keys = ["wall_idx", "agent_idx", "goal_idx", + "plate_pile_idx", "onion_pile_idx", "pot_idx"] + symbol_to_key = {"W": "wall_idx", + "A": "agent_idx", + "X": "goal_idx", + "B": "plate_pile_idx", + "O": "onion_pile_idx", + "P": "pot_idx"} + + layout_dict = {key: [] for key in keys} + layout_dict["height"] = len(rows) + layout_dict["width"] = len(rows[0]) + width = len(rows[0]) + + for i, row in enumerate(rows): + for j, obj in enumerate(row): + idx = width * i + j + if obj in symbol_to_key.keys(): + # Add object + layout_dict[symbol_to_key[obj]].append(idx) + if obj in ["X", "B", "O", "P"]: + # These objects are also walls technically + layout_dict["wall_idx"].append(idx) + elif obj == " ": + # Empty cell + continue + + for key in symbol_to_key.values(): + # Transform lists to arrays + layout_dict[key] = jnp.array(layout_dict[key]) + + return FrozenDict(layout_dict) + + +# load all_lvl_strs.json +# all_lvls_strs = json.load( +# open("jaxmarl/environments/overcooked/10x15_all_lvl_strs.json", "r")) +# gan_mlp_layouts = all_lvls_strs["gan_milp"] + +# automatic_overcooked_layouts_10_15 = { +# str(k): layout_grid_to_dict(v) for k, v in enumerate(gan_mlp_layouts) +# } + +# all_lvls_strs = json.load( +# open("jaxmarl/environments/overcooked/6x9_all_lvl_strs.json", "r")) +# gan_mlp_layouts = all_lvls_strs["gan_milp"] + +# automatic_overcooked_layouts_6_9 = { +# str(k): layout_grid_to_dict(v) for k, v in enumerate(gan_mlp_layouts) +# } + +# all_lvls_strs = json.load( +# open("jaxmarl/environments/overcooked/6x9_all_lvl_strs_simple.json", "r")) + +# gan_mlp_layouts = all_lvls_strs["gan_milp"] + +# automatic_overcooked_layouts_6_9_simple = { +# str(k): layout_grid_to_dict(v) for k, v in enumerate(gan_mlp_layouts) +# } + + +overcooked_layouts = { + "cramped_room": FrozenDict(cramped_room), + "asymm_advantages": FrozenDict(asymm_advantages), + "asymm_advantages_m": FrozenDict(asymm_advantages), + "coord_ring": FrozenDict(coord_ring), + "forced_coord": FrozenDict(forced_coord), + "counter_circuit": layout_grid_to_dict(counter_circuit_grid), + "asymm_advantages_6_9": layout_grid_to_dict(asymm_advantages_6_9), + "counter_circuit_6_9": layout_grid_to_dict(counter_circuit_6_9), + "forced_coord_6_9": layout_grid_to_dict(forced_coord_6_9), + "cramped_room_6_9": layout_grid_to_dict(cramped_room_6_9), + "coord_ring_6_9": layout_grid_to_dict(coord_ring_6_9), + "quad_6_9": layout_grid_to_dict(quad_6_9), + "quad_6_9_1": layout_grid_to_dict(quad_6_9_1), + "quad_6_9_2": layout_grid_to_dict(quad_6_9_2), + "quad_6_9_3": layout_grid_to_dict(quad_6_9_3), + "quad_6_9_4": layout_grid_to_dict(quad_6_9_4), + "quad_6_9_5": layout_grid_to_dict(quad_6_9_5), + "quad_6_9_6": layout_grid_to_dict(quad_6_9_6), + "quad_6_9_7": layout_grid_to_dict(quad_6_9_7), + "quad_6_9_8": layout_grid_to_dict(quad_6_9_8), + "quad_6_9_9": layout_grid_to_dict(quad_6_9_9), + "quad_6_9T": layout_grid_to_dict(quad_6_9T), + "quad_6_9M": layout_grid_to_dict(quad_6_9M), + "asymm_advantages_m": layout_grid_to_dict(asymm_advantages_m), + "asymm_advantages_m1": layout_grid_to_dict(asymm_advantages_m1), + "asymm_advantages_m2": layout_grid_to_dict(asymm_advantages_m2), + "asymm_advantages_m3": layout_grid_to_dict(asymm_advantages_m3), + "asymm_advantages_m4": layout_grid_to_dict(asymm_advantages_m4), + "asymm_advantages_m5": layout_grid_to_dict(asymm_advantages_m5), + "asymm_advantages_m6": layout_grid_to_dict(asymm_advantages_m6), + "asymm_advantages_m7": layout_grid_to_dict(asymm_advantages_m7), + "asymm_advantages_m8": layout_grid_to_dict(asymm_advantages_m8), + "asymm_advantages_10_15": layout_grid_to_dict(asymm_advantages_10_15), + "asymm_advantages_B_10_15": layout_grid_to_dict(asymm_advantages_B_10_15), + "asymm_advantages_M_10_15": layout_grid_to_dict(asymm_advantages_M_10_15), + "asymm_advantages_L_10_15": layout_grid_to_dict(asymm_advantages_L_10_15), +} diff --git a/src/minimax/envs/overcooked_proc/overcooked.py b/src/minimax/envs/overcooked_proc/overcooked.py new file mode 100644 index 0000000..316dfb9 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/overcooked.py @@ -0,0 +1,1389 @@ +# Edited from JaxMarl: https://github.com/FLAIROx/JaxMARL/tree/main/jaxmarl/environments/overcooked + +from enum import IntEnum +from hmac import new +import os +import random +import time + +import imageio +import numpy as np +import jax +import jax.numpy as jnp +from jax import lax +from typing import Tuple, Dict +import chex +from flax import struct +from sympy import jn + +from minimax.envs import environment, spaces +from minimax.envs.registration import register +from minimax.envs.overcooked_proc.layouts import layout_grid_to_onehot_dict +import minimax.util.graph as _graph_util +from minimax.envs.viz.overcooked_visualizer import OvercookedVisualizer + +from .common import EnvInstance, make_overcooked_map + +asymm_advantages_6_9 = """ +WWWWWWWWW +O WXWOW X +W P A W +WA P W +WWWBWBWWW +WWWWWWWWW +""" + +counter_circuit_6_9 = """ +WWWPPWWWW +W A WW +B WWWW XW +W AWW +WWWOOWWWW +WWWWWWWWW +""" + +forced_coord_6_9 = """ +WWWPWWWWW +OAWAPWWWW +O W WWWWW +B W WWWWW +WWWXWWWWW +WWWWWWWWW +""" + +cramped_room_6_9 = """ +WWPWWWWWW +OAA OWWWW +W WWWWW +WBWXWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +coord_ring_6_9 = """ +WWWPWWWWW +WA APWWWW +B W WWWWW +O WWWWW +WOXWWWWWW +WWWWWWWWW +""" + +forced_coord_5_5 = """ +WWWPW +OAWAP +O W W +B W W +WWWXW +""" + +cramped_room_5_5 = """ +WWPWW +OAA O +W W +WBWXW +WWWWW +""" + +coord_ring_5_5 = """ +WWWPW +WA AP +B W W +O W +WOXWW +""" + + +class Actions(IntEnum): + # Turn left, turn right, move forward + right = 0 + down = 1 + left = 2 + up = 3 + stay = 4 + interact = 5 + done = 6 + + +@struct.dataclass +class EnvState: + agent_pos: chex.Array + agent_dir: chex.Array + agent_dir_idx: chex.Array + agent_inv: chex.Array + goal_pos: chex.Array + pot_pos: chex.Array + wall_map: chex.Array + maze_map: chex.Array + bowl_pile_pos: chex.Array + onion_pile_pos: chex.Array + time: int + terminal: bool + +@struct.dataclass +class EnvParams: + height: int = 6 + width: int = 9 + h_min: int = 4 + w_min: int = 4 + n_walls: int = 5 + agent_view_size: int = 5 + replace_wall_pos: bool = False + normalize_obs: bool = False + sample_n_walls: bool = False # Sample n_walls uniformly in [0, n_walls] + max_steps: int = 400 + singleton_seed: int = -1 + max_episode_steps: int = 400 + + +# Pot status indicated by an integer, which ranges from 23 to 0 +POT_EMPTY_STATUS = 23 # 22 = 1 onion in pot; 21 = 2 onions in pot; 20 = 3 onions in pot +# 3 onions. Below this status, pot is cooking, and status acts like a countdown timer. +POT_FULL_STATUS = 20 +POT_READY_STATUS = 0 +# A pot has at most 3 onions. A soup contains exactly 3 onions. +MAX_ONIONS_IN_POT = 3 + +URGENCY_CUTOFF = 40 # When this many time steps remain, the urgency layer is flipped on +DELIVERY_REWARD = 20 + + +SHAPED_REWARD = { + "PLACEMENT_IN_POT_REW": 0, + "DISH_PICKUP_REWARD": 3, + "SOUP_PICKUP_REWARD": 5, + "PICKUP_TOMATO_REWARD": 0, + "DISH_DISP_DISTANCE_REW": 0, + "POT_DISTANCE_REW": 0, + "SOUP_DISTANCE_REW": 0, +} + +OBJECT_TO_INDEX = { + "unseen": 0, + "empty": 1, + "wall": 2, + "onion": 3, + "onion_pile": 4, + "plate": 5, + "plate_pile": 6, + "goal": 7, + "pot": 8, + "dish": 9, + "agent": 10, +} + + +COLORS = { + 'red': np.array([255, 0, 0]), + 'green': np.array([0, 255, 0]), + 'blue': np.array([0, 0, 255]), + 'purple': np.array([112, 39, 195]), + 'yellow': np.array([255, 255, 0]), + 'grey': np.array([100, 100, 100]), + 'white': np.array([255, 255, 255]), + 'black': np.array([25, 25, 25]), + 'orange': np.array([230, 180, 0]), +} + + +COLOR_TO_INDEX = { + 'red': 0, + 'green': 1, + 'blue': 2, + 'purple': 3, + 'yellow': 4, + 'grey': 5, + 'white': 6, + 'black': 7, + 'orange': 8, +} + +LAYOUT_STR_TO_LAYOUT = { + "asymm_advantages_6_9": asymm_advantages_6_9, + "counter_circuit_6_9": counter_circuit_6_9, + "forced_coord_6_9": forced_coord_6_9, + "cramped_room_6_9": cramped_room_6_9, + "coord_ring_6_9": coord_ring_6_9, + "coord_ring_5_5": coord_ring_5_5, + "forced_coord_5_5": forced_coord_5_5, + "cramped_room_5_5": cramped_room_5_5, +} + + +OBJECT_INDEX_TO_VEC = jnp.array([ + jnp.array([OBJECT_TO_INDEX['unseen'], 0, 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['wall'], COLOR_TO_INDEX['grey'], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['onion'], + COLOR_TO_INDEX["yellow"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['onion_pile'], + COLOR_TO_INDEX["yellow"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['plate'], + COLOR_TO_INDEX["white"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['plate_pile'], + COLOR_TO_INDEX["white"], 0], dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['goal'], COLOR_TO_INDEX['green'], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['pot'], COLOR_TO_INDEX['black'], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['dish'], COLOR_TO_INDEX["white"], 0], + dtype=jnp.uint8), + jnp.array([OBJECT_TO_INDEX['agent'], COLOR_TO_INDEX['red'], 0], + dtype=jnp.uint8), # Default color and direction +]) + + +# Map of agent direction indices to vectors +DIR_TO_VEC = jnp.array([ + # Pointing right (positive X) + # (1, 0), # right + # (0, 1), # down + # (-1, 0), # left + # (0, -1), # up + (0, -1), # NORTH + (0, 1), # SOUTH + (1, 0), # EAST + (-1, 0), # WEST +], dtype=jnp.int8) + + +def _obtain_from_layout(key, layout, h, w, random_reset, num_agents): + all_pos = np.arange(np.prod([h, w]), dtype=jnp.uint32) + occupied_mask = layout.get("wall_idx") + # occupied_mask = jnp.zeros_like(all_pos) + # occupied_mask = occupied_mask.at[wall_idx].set(1) + wall_map = occupied_mask.reshape(h, w).astype(jnp.bool_) + + # Reset agent position + dir + key, subkey = jax.random.split(key) + agent_idx = jax.random.choice(subkey, all_pos, shape=(num_agents,), + p=(~occupied_mask.astype(jnp.bool_)).astype(jnp.uint8), replace=False) + # agent_idx = jnp.zeros_like(occupied_mask).at[agent_idx].set(1) + + # Replace with fixed layout if applicable. Also randomize if agent position not provided + # agent_idx = random_reset*agent_idx + \ # (1-random_reset)* + agent_idx = layout.get("agent_idx", agent_idx) + agent_pos = jnp.array([agent_idx % w, agent_idx // w], + dtype=jnp.uint32).transpose() # dim = n_agents x 2 + # agent_pos = agent_idx.reshape(h,w) + occupied_mask = occupied_mask.at[agent_idx].set(1) + + key, subkey = jax.random.split(key) + agent_dir_idx = jax.random.choice(subkey, jnp.arange( + len(DIR_TO_VEC), dtype=jnp.int32), shape=(num_agents,)) + agent_dir = DIR_TO_VEC.at[agent_dir_idx].get() # dim = n_agents x 2 + + empty_table_mask = jnp.zeros_like(all_pos) + empty_table_mask = jnp.array(layout.get("empty_table_idx")).reshape(h, w) + + goal_idx = layout.get("goal_idx") + goal_pos = goal_idx.reshape(h, w) + empty_table_mask = empty_table_mask.at[goal_idx].set(0) + + onion_pile_idx = layout.get("onion_pile_idx") + onion_pile_pos = onion_pile_idx.reshape(h, w) + empty_table_mask = empty_table_mask.at[onion_pile_idx].set(0) + + plate_pile_idx = layout.get("plate_pile_idx") + plate_pile_pos = plate_pile_idx.reshape(h, w) + empty_table_mask = empty_table_mask.at[plate_pile_idx].set(0) + + pot_idx = layout.get("pot_idx") + pot_pos = pot_idx.reshape(h, w) + empty_table_mask = empty_table_mask.at[pot_idx].set(0) + + key, subkey = jax.random.split(key) + pot_status = pot_idx * \ + jax.random.randint(subkey, (pot_idx.shape[0],), 0, 24, dtype=jnp.uint8) + pot_status = pot_status * random_reset + \ + (1-random_reset) * jnp.ones((pot_idx.shape[0]), dtype=jnp.uint8) * 23 + return wall_map, goal_pos, agent_pos, agent_dir, agent_dir_idx, plate_pile_pos, onion_pile_pos, pot_pos, pot_status + + +class Overcooked(environment.Environment): + """Overcooked Procedural Multi-Agent""" + + def __init__( + self, + height: int, + width: int, + random_reset: bool = False, + n_walls=25, + agent_view_size=5, + replace_wall_pos=False, + max_steps=400, + normalize_obs=False, + sample_n_walls=False, + fix_to_single_layout=None, + dense_obs=False, + singleton_seed=-1 + ): + # Sets self.num_agents to 2 + super().__init__() + + self.num_agents = 2 + self.default_shaped_reward_coeff = 0.0 + # self.obs_shape = (agent_view_size, agent_view_size, 3) + # Observations given by 26 channels, most of which are boolean masks + # The idea is that we never create levels biger that this for zero padding. + self.width = width + self.height = height + self.num_features = 62 # Akin to the original Overcooked-AI + + # Hard coded. Only affects map padding -- not observations. + self.agent_view_size = 5 + self.agents = ["agent_0", "agent_1"] + # Fixes Resets to this layout instead to a random one. + # Mostly used for debugging. + # Example: "asymm_advantages_6_9" -> asymm_advantages_6_9 + self.fix_to_single_layout = fix_to_single_layout + self.dense_obs = dense_obs + + # Define the observation function + if dense_obs: # (62,) + self.get_obs = self.get_obs_dense + self.obs_shape = (self.num_features,) + else: # (h, w, 26,) + self.get_obs = self.get_obs_sparse + self.obs_shape = (self.width, self.height, 26) + + self.action_set = jnp.array([ + Actions.right, + Actions.down, + Actions.left, + Actions.up, + Actions.stay, + Actions.interact, + ]) + + self.random_reset = random_reset + self.max_steps = max_steps + + self.params = EnvParams( + height=height, + width=width, + n_walls=n_walls, + agent_view_size=agent_view_size, + replace_wall_pos=replace_wall_pos and not sample_n_walls, + max_steps=max_steps, + normalize_obs=normalize_obs, + sample_n_walls=sample_n_walls, + singleton_seed=-1, + max_episode_steps=max_steps, + ) + + def step_env( + self, + key: chex.PRNGKey, + state: EnvState, + actions: Dict[str, chex.Array], + ) -> Tuple[Dict[str, chex.Array], EnvState, Dict[str, float], Dict[str, bool], Dict]: + """Perform single timestep state transition.""" + + acts = self.action_set.take(indices=jnp.array( + [actions["agent_0"], actions["agent_1"]])) + + state, reward, shaped_reward_alice, shaped_reward_bob = self.step_agents( + key, state, acts) + + state = state.replace(time=state.time + 1) + + done = self.is_terminal(state) + state = state.replace(terminal=done) + + obs = self.get_obs(state) + rewards = { + "agent_0": reward, + "agent_1": reward + } + dones = {"agent_0": done, "agent_1": done, "__all__": done} + + return ( + lax.stop_gradient(obs), + lax.stop_gradient(state), + rewards, + dones, + { + "sparse_reward": jnp.array([reward, reward]), + "shaped_reward": jnp.array([shaped_reward_alice, shaped_reward_bob]), + }, + ) + + def sample_random_layout(self, key: chex.PRNGKey, h, w) -> Dict[str, chex.Array]: + """Samples a random layout that might or might not be playable. + """ + params = self.params + + all_pos = np.arange(np.prod([h, w]), dtype=jnp.uint8) + + key, walls_key, nwalls_key, goal_key, plate_pile_key, onion_pile_key, pot_key, agpos_key = jax.random.split( + key, 8) + wall_idx = jax.random.choice( + walls_key, all_pos, + shape=(params.n_walls,), + replace=params.replace_wall_pos) + + if params.sample_n_walls: + sampled_n_walls = jax.random.randint( + nwalls_key, (), minval=0, maxval=params.n_walls) + sample_wall_mask = jnp.arange(params.n_walls) < sampled_n_walls + dummy_wall_idx = wall_idx.at[0].get().repeat(params.n_walls) + wall_idx = jax.lax.select( + sample_wall_mask, + wall_idx, + dummy_wall_idx + ) + + walls = jnp.zeros_like(all_pos, dtype=jnp.uint8) + walls = walls.at[wall_idx].set(1) + walls = walls.reshape(h, w) + walls = walls.at[:, 0].set(1) + walls = walls.at[0, :].set(1) + walls = walls.at[:, -1].set(1) + walls = walls.at[-1, :].set(1).reshape(-1) + + occupied_obj_mask = jnp.zeros_like(all_pos, dtype=jnp.uint8) + wall_mask = occupied_obj_mask + walls + + # Do not want corners to have objects + occupied_obj_mask = occupied_obj_mask.reshape(h, w) + occupied_obj_mask = occupied_obj_mask.at[0, 0].set(1) + occupied_obj_mask = occupied_obj_mask.at[-1, -1].set(1) + occupied_obj_mask = occupied_obj_mask.at[0, -1].set(1) + occupied_obj_mask = occupied_obj_mask.at[-1, 0].set(1) + occupied_obj_mask = occupied_obj_mask.reshape(-1) + + def add_1_or_2_items(key, all_pos, wall_mask, occupied_obj_mask): + # occupied_obj_mask is only objects on tables so we can do: + possible_positions = wall_mask - occupied_obj_mask + obj_mask = jnp.zeros_like(all_pos, dtype=jnp.uint8) + key, subkey1, subkey2, subkey3 = jax.random.split(key, 4) + item_idx_1 = jax.random.choice(subkey1, all_pos, shape=( + 1,), p=(possible_positions.astype(jnp.bool_)).astype(jnp.uint8)) + + and_2 = jax.random.bernoulli(subkey2, 0.5) + + item_idx_2 = jax.random.choice(subkey3, all_pos, shape=( + 1,), p=(possible_positions.astype(jnp.bool_)).astype(jnp.uint8)) + + obj_mask = obj_mask.at[item_idx_1].set(1) + + update_2 = jnp.logical_or( + obj_mask.at[item_idx_2].get(), and_2.astype(jnp.uint8)) + obj_mask = obj_mask.at[item_idx_2].set(update_2) + return obj_mask + + goal_pos = add_1_or_2_items( + goal_key, all_pos, wall_mask, occupied_obj_mask) + occupied_obj_mask = occupied_obj_mask + goal_pos + + plate_pile_pos = add_1_or_2_items( + plate_pile_key, all_pos, wall_mask, occupied_obj_mask) + occupied_obj_mask = occupied_obj_mask + plate_pile_pos + + onion_pile_pos = add_1_or_2_items( + onion_pile_key, all_pos, wall_mask, occupied_obj_mask) + occupied_obj_mask = occupied_obj_mask + onion_pile_pos + + pot_pos = add_1_or_2_items( + pot_key, all_pos, wall_mask, occupied_obj_mask) + occupied_obj_mask = occupied_obj_mask + pot_pos + + agent_idx = jax.random.choice(agpos_key, all_pos, shape=(2,), replace=False, p=( + ~wall_mask.astype(jnp.bool_)).astype(jnp.uint8)) + # occupied_mask = occupied_mask.at[agent_idx].set(2) + + layout = { + "height": self.height, + "width": self.width, + "wall_idx": walls, + "empty_table_idx": walls - occupied_obj_mask, + "agent_idx": agent_idx, + "goal_idx": goal_pos, + "plate_pile_idx": plate_pile_pos, + "onion_pile_idx": onion_pile_pos, + "pot_idx": pot_pos + } + return layout + + def reset_env( # NOTE: Has been renamed to fit minimax + self, + key: chex.PRNGKey, + ) -> Tuple[Dict[str, chex.Array], EnvState]: + """Reset environment state based on `self.random_reset` + + If True, everything is randomized, including agent inventories and positions, pot states and items on counters + If False, only resample agent orientations + + In both cases, the environment layout is determined by `self.layout` + """ + # Whether to fully randomize the start state + random_reset = self.random_reset + + h = self.height + w = self.width + num_agents = self.num_agents + + if self.fix_to_single_layout is None: + layout = self.sample_random_layout(key, h, w) + else: + layout = layout_grid_to_onehot_dict( + LAYOUT_STR_TO_LAYOUT[self.fix_to_single_layout]) + + wall_map, goal_pos, agent_pos, agent_dir, agent_dir_idx, plate_pile_pos, onion_pile_pos, pot_pos, pot_status\ + = _obtain_from_layout(key, layout, h, w, random_reset, num_agents) + + onion_pos = jnp.zeros((h, w), dtype=jnp.uint8) + plate_pos = jnp.zeros((h, w), dtype=jnp.uint8) + dish_pos = jnp.zeros((h, w), dtype=jnp.uint8) + + maze_map = make_overcooked_map( + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + plate_pile_pos, + onion_pile_pos, + pot_pos, + pot_status, + onion_pos, + plate_pos, + dish_pos, + pad_obs=True, + num_agents=self.num_agents, + agent_view_size=self.agent_view_size + ) + # Its to make padding static with respect to the jitted code later. + # Its static since we compute it in advance now. + padding = (maze_map.shape[0]-h) // 2 + + # agent inventory (empty by default, can be randomized) + key, subkey = jax.random.split(key) + possible_items = jnp.array([OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['onion'], + OBJECT_TO_INDEX['plate'], OBJECT_TO_INDEX['dish']]) + random_agent_inv = jax.random.choice( + subkey, possible_items, shape=(num_agents,), replace=True) + agent_inv = random_reset * random_agent_inv + \ + (1-random_reset) * \ + jnp.array([OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['empty']]) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=agent_dir, + agent_dir_idx=agent_dir_idx, + agent_inv=agent_inv, + goal_pos=goal_pos, + pot_pos=pot_pos, + onion_pile_pos=onion_pile_pos, + bowl_pile_pos=plate_pile_pos, + wall_map=wall_map.astype(jnp.bool_), + maze_map=maze_map, + time=0, + terminal=False, + ) + + self.padding = padding + obs = self.get_obs(state) + + return lax.stop_gradient(obs), lax.stop_gradient(state) + + def get_obs_dense(self, state: EnvState) -> Dict[str, chex.Array]: + """ + Inspired by the original Overcooked-AI we also add a dense observation to the environment. + We use this to built the OvercookedUED light challange as it is significantly less sparse then the original observation. + + From their doc (https://github.com/HumanCompatibleAI/overcooked_ai/blob/cff884ccf5709658ee4cd489e63367200b4c86d6/src/overcooked_ai_py/mdp/overcooked_mdp.py#L2579): + Returns: + ordered_features (list[np.Array]): The ith element contains a player-centric featurized view for the ith player + + The encoding for player i is as follows: + + [player_i_features, other_player_features player_i_dist_to_other_players, player_i_position] + + player_{i}_features (length num_pots*10 + 24): + pi_orientation: length 4 one-hot-encoding of direction currently facing + pi_obj: length 4 one-hot-encoding of object currently being held (all 0s if no object held) + pi_wall_{j}: {0, 1} boolean value of whether player i has wall immediately in direction j + pi_closest_{onion|tomato|dish|soup|serving|empty_counter}: (dx, dy) where dx = x dist to item, dy = y dist to item. (0, 0) if item is currently held + pi_cloest_soup_n_{onions|tomatoes}: int value for number of this ingredient in closest soup + pi_closest_pot_{j}_exists: {0, 1} depending on whether jth closest pot found. If 0, then all other pot features are 0. Note: can + be 0 even if there are more than j pots on layout, if the pot is not reachable by player i + pi_closest_pot_{j}_{is_empty|is_full|is_cooking|is_ready}: {0, 1} depending on boolean value for jth closest pot + pi_closest_pot_{j}_{num_onions|num_tomatoes}: int value for number of this ingredient in jth closest pot + pi_closest_pot_{j}_cook_time: int value for seconds remaining on soup. -1 if no soup is cooking + pi_closest_pot_{j}: (dx, dy) to jth closest pot from player i location + + other_player_features (length (num_players - 1)*(num_pots*10 + 24)): + ordered concatenation of player_{j}_features for j != i + + player_i_dist_to_other_players (length (num_players - 1)*2): + [player_j.pos - player_i.pos for j != i] + + player_i_position (length 2) + """ + agent_dir = state.agent_dir + agent_inv = state.agent_inv + maze_map = state.maze_map + + w = self.width + h = self.height + + padding = 4 + maze_map = maze_map[padding:-padding, padding:-padding, :] + + def get_player_rep(player_idx: int, state): + + agent_pos = state.agent_pos[player_idx] + + # pi_orientation: length 4 one-hot-encoding of direction currently facing + pi_orientation = jnp.zeros((4)).at[state.agent_dir_idx].set(1) + + # pi_obj: length 3 one-hot-encoding of object currently being held (all 0s if no object held) + pi_obj = OBJECT_TO_INDEX["empty"] + (agent_inv[player_idx] == OBJECT_TO_INDEX["onion"]) * jnp.array([1, 0, 0], dtype=jnp.uint8)\ + + (agent_inv[player_idx] == OBJECT_TO_INDEX["plate"]) * jnp.array([0, 1, 0], dtype=jnp.uint8)\ + + (agent_inv[player_idx] == OBJECT_TO_INDEX["dish"] + ) * jnp.array([0, 0, 1], dtype=jnp.uint8) + + # pi_wall_{j}: {0, 1} boolean value of whether player i has wall immediately in direction j + fwd_pos_0 = agent_pos + DIR_TO_VEC[0] + is_wall_0 = state.wall_map.at[fwd_pos_0[1], fwd_pos_0[0]].get() + + fwd_pos_1 = agent_pos + DIR_TO_VEC[1] + is_wall_1 = state.wall_map.at[fwd_pos_1[1], fwd_pos_1[0]].get() + + fwd_pos_2 = agent_pos + DIR_TO_VEC[2] + is_wall_2 = state.wall_map.at[fwd_pos_2[1], fwd_pos_2[0]].get() + + fwd_pos_3 = agent_pos + DIR_TO_VEC[3] + is_wall_3 = state.wall_map.at[fwd_pos_3[1], fwd_pos_3[0]].get() + + pi_wall_j = jnp.array([is_wall_0, is_wall_1, is_wall_2, is_wall_3]) + + # pi_closest_{onion|dish|soup|serving|empty_counter}: (dx, dy) where dx = x dist to item, dy = y dist to item. (0, 0) if item is currently held + def find_closest_between_masks(agent_pos, object_map, name): + obj_idx = OBJECT_TO_INDEX[name] + padded_pos = jnp.argwhere( + object_map.T == obj_idx, size=2, # w*h, + fill_value=jnp.inf) + dist = padded_pos-agent_pos + abs_dist = jnp.abs(dist) + manhatten = abs_dist.sum(-1) + closest_idx = jnp.argmin(manhatten) + clostest_obj_pos = padded_pos[closest_idx] + dxdy_obj_ag_inf = dist[closest_idx] + dxdy_obj_ag = jnp.nan_to_num(dxdy_obj_ag_inf, nan=0, posinf=0) + return clostest_obj_pos.astype(jnp.uint8), dxdy_obj_ag + + object_map = maze_map[:, :, 0] + pos_closest_pot, pi_closest_pot = find_closest_between_masks( + agent_pos, object_map, "pot") + _, pi_closest_onion = find_closest_between_masks( + agent_pos, object_map, "onion") + _, pi_closest_plate = find_closest_between_masks( + agent_pos, object_map, "plate") + _, pi_closest_dish = find_closest_between_masks( + agent_pos, object_map, "dish") + _, pi_closest_goal = find_closest_between_masks( + agent_pos, object_map, "goal") + # If it has something on it its type is not wall -> i.e. walls are always empty + # empty_wall_map = (maze_map[:,:,0] == OBJECT_TO_INDEX["wall"]).astype(jnp.uint8) + _, pi_closest_wall = find_closest_between_masks( + agent_pos, object_map, "wall") + + # pi_cloest_soup_n_{onions}: int value for number of this ingredient in closest soup + # Not apllicable: We only have 3 onion soups + # pi_closest_pot_{j}_exists: {0, 1} depending on whether jth closest pot found. If 0, then all other pot features are 0. Note: can + # be 0 even if there are more than j pots on layout, if the pot is not reachable by player i + # pi_closest_pot_{j}_{is_empty|is_full|is_cooking|is_ready}: {0, 1} depending on boolean value for jth closest pot + # pi_closest_pot_{j}_{num_onions|num_tomatoes}: int value for number of this ingredient in jth closest pot + # pi_closest_pot_{j}_cook_time: int value for seconds remaining on soup. -1 if no soup is cooking + # pi_closest_pot_{j}: (dx, dy) to jth closest pot from player i location + closest_pot = maze_map.at[pos_closest_pot[1], + pos_closest_pot[0]].get() + + # agent_obj = maze_map.at[agent_pos[1], agent_pos[0]].get() + + path_len = _graph_util.shortest_path_len( + state.wall_map, agent_pos, pos_closest_pot) + + # pi_closest_pot_{j}_exists + pi_closest_pot_exists = path_len > 0 + pi_closest_pot_is_empty = ( + closest_pot[2] == 23) * pi_closest_pot_exists + pi_closest_pot_is_full = ( + jnp.logical_and(closest_pot[2] <= 20, closest_pot[2] > 0)) * pi_closest_pot_exists + pi_closest_pot_is_cooking = ( + jnp.logical_and(closest_pot[2] <= 19, closest_pot[2] > 0)) * pi_closest_pot_exists + pi_closest_pot_is_ready = ( + closest_pot[2] == 0) * pi_closest_pot_exists + pi_closest_pot_num_onions = ( + (closest_pot[2] <= 20)*3 + (closest_pot[2] == 21)*2 + (closest_pot[2] == 22)*1) * pi_closest_pot_exists + pi_closest_pot_cook_time = pi_closest_pot_is_cooking * \ + closest_pot[2] + + return jnp.hstack([ + pi_orientation, pi_obj, pi_wall_j, pi_closest_onion, pi_closest_plate, pi_closest_dish, + pi_closest_goal, pi_closest_wall, pi_closest_pot_exists, pi_closest_pot_is_empty, pi_closest_pot_is_full, + pi_closest_pot_is_cooking, pi_closest_pot_is_ready, pi_closest_pot_num_onions, pi_closest_pot_cook_time, pi_closest_pot + ]) + + agent_vec_0 = get_player_rep(0, state) + agent_vec_1 = get_player_rep(1, state) + + obs = { + 'agent_0': jnp.hstack([agent_vec_0, agent_vec_1, state.agent_pos[0, 1], state.agent_pos[0, 0]]), + 'agent_1': jnp.hstack([agent_vec_1, agent_vec_0, state.agent_pos[1, 1], state.agent_pos[1, 0]]) + } + return obs + + def get_obs_sparse(self, state: EnvState) -> Dict[str, chex.Array]: + """Return a full observation, of size(height x width x n_layers), where n_layers = 26. + Layers are of shape(height x width) and are binary(0/1) except where indicated otherwise. + The obs is very sparse(most elements are 0), which prob. contributes to generalization problems in Overcooked. + A v2 of this environment should have much more efficient observations, e.g. using item embeddings + + The list of channels is below. Agent-specific layers are ordered so that an agent perceives its layers first. + Env layers are the same (and in same order) for both agents. + + Agent positions: + 0. position of agent i(1 at agent loc, 0 otherwise) + 1. position of agent(1-i) + + Agent orientations: + 2-5. agent_{i}_orientation_0 to agent_{i}_orientation_3(layers are entirely zero except for the one orientation + layer that matches the agent orientation. That orientation has a single 1 at the agent coordinates.) + 6-9. agent_{i-1}_orientation_{dir} + + Static env positions(1 where object of type X is located, 0 otherwise.): + 10. pot locations + 11. counter locations(table) + 12. onion pile locations + 13. tomato pile locations(tomato layers are included for consistency, but this env does not support tomatoes) + 14. plate pile locations + 15. delivery locations(goal) + + Pot and soup specific layers. These are non-binary layers: + 16. number of onions in pot(0, 1, 2, 3) for elements corresponding to pot locations. Nonzero only for pots that + have NOT started cooking yet. When a pot starts cooking (or is ready), the corresponding element is set to 0 + 17. number of tomatoes in pot. + 18. number of onions in soup(0, 3) for elements corresponding to either a cooking/done pot or to a soup(dish) + ready to be served. This is a useless feature since all soups have exactly 3 onions, but it made sense in the + full Overcooked where recipes can be a mix of tomatoes and onions + 19. number of tomatoes in soup + 20. pot cooking time remaining. [19 -> 1] for pots that are cooking. 0 for pots that are not cooking or done + 21. soup done. (Binary) 1 for pots done cooking and for locations containing a soup(dish). O otherwise. + + Variable env layers(binary): + 22. plate locations + 23. onion locations + 24. tomato locations + + Urgency: + 25. Urgency. The entire layer is 1 there are 40 or fewer remaining time steps. 0 otherwise + """ + width = self.obs_shape[0] + height = self.obs_shape[1] + n_channels = self.obs_shape[2] + # NOTE: Original code here was: padding = (state.maze_map.shape[0]-height) // 2 + padding = 4 + # padding = state.padding # Must be somehow static + + maze_map = state.maze_map[padding:-padding, padding:-padding, 0] + soup_loc = jnp.array( + maze_map == OBJECT_TO_INDEX["dish"], dtype=jnp.uint8) + + pot_loc_layer = jnp.array( + maze_map == OBJECT_TO_INDEX["pot"], dtype=jnp.uint8) + pot_status = state.maze_map[padding:-padding, + padding: -padding, 2] * pot_loc_layer + onions_in_pot_layer = jnp.minimum(POT_EMPTY_STATUS - pot_status, MAX_ONIONS_IN_POT) * ( + pot_status >= POT_FULL_STATUS) # 0/1/2/3, as long as not cooking or not done + onions_in_soup_layer = jnp.minimum(POT_EMPTY_STATUS - pot_status, MAX_ONIONS_IN_POT) * (pot_status < POT_FULL_STATUS) \ + * pot_loc_layer + MAX_ONIONS_IN_POT * soup_loc # 0/3, as long as cooking or done + pot_cooking_time_layer = pot_status * \ + (pot_status < POT_FULL_STATUS) # Timer: 19 to 0 + # Ready soups, plated or not + soup_ready_layer = pot_loc_layer * \ + (pot_status == POT_READY_STATUS) + soup_loc + urgency_layer = jnp.ones(maze_map.shape, dtype=jnp.uint8) * \ + ((self.max_steps - state.time) < URGENCY_CUTOFF) + + agent_pos_layers = jnp.zeros((2, height, width), dtype=jnp.uint8) + agent_pos_layers = agent_pos_layers.at[0, + state.agent_pos[0, 1], state.agent_pos[0, 0]].set(1) + agent_pos_layers = agent_pos_layers.at[1, + state.agent_pos[1, 1], state.agent_pos[1, 0]].set(1) + + # Add agent inv: This works because loose items and agent cannot overlap + agent_inv_items = jnp.expand_dims( + state.agent_inv, (1, 2)) * agent_pos_layers + maze_map = jnp.where(jnp.sum(agent_pos_layers, 0), + agent_inv_items.sum(0), maze_map) + soup_ready_layer = soup_ready_layer + + (jnp.sum(agent_inv_items, 0) == + OBJECT_TO_INDEX["dish"]) * jnp.sum(agent_pos_layers, 0) + onions_in_soup_layer = onions_in_soup_layer \ + + (jnp.sum(agent_inv_items, 0) == + OBJECT_TO_INDEX["dish"]) * 3 * jnp.sum(agent_pos_layers, 0) + + env_layers = [ + # Channel 10 + jnp.array(maze_map == OBJECT_TO_INDEX["pot"], dtype=jnp.uint8), + jnp.array(maze_map == OBJECT_TO_INDEX["wall"], dtype=jnp.uint8), + jnp.array( + maze_map == OBJECT_TO_INDEX["onion_pile"], dtype=jnp.uint8), + # tomato pile + jnp.zeros(maze_map.shape, dtype=jnp.uint8), + jnp.array( + maze_map == OBJECT_TO_INDEX["plate_pile"], dtype=jnp.uint8), + # 15 + jnp.array(maze_map == OBJECT_TO_INDEX["goal"], dtype=jnp.uint8), + jnp.array(onions_in_pot_layer, dtype=jnp.uint8), + # tomatoes in pot + jnp.zeros(maze_map.shape, dtype=jnp.uint8), + jnp.array(onions_in_soup_layer, dtype=jnp.uint8), + # tomatoes in soup + jnp.zeros(maze_map.shape, dtype=jnp.uint8), + jnp.array(pot_cooking_time_layer, + dtype=jnp.uint8), # 20 + jnp.array(soup_ready_layer, dtype=jnp.uint8), + jnp.array(maze_map == OBJECT_TO_INDEX["plate"], dtype=jnp.uint8), + jnp.array(maze_map == OBJECT_TO_INDEX["onion"], dtype=jnp.uint8), + # tomatoes + jnp.zeros(maze_map.shape, dtype=jnp.uint8), + urgency_layer, # 25 + ] + + # Agent related layers + agent_direction_layers = jnp.zeros((8, height, width), dtype=jnp.uint8) + dir_layer_idx = state.agent_dir_idx+jnp.array([0, 4]) + agent_direction_layers = agent_direction_layers.at[dir_layer_idx, :, :].set( + agent_pos_layers) + + # Both agent see their layers first, then the other layer + alice_obs = jnp.zeros((n_channels, height, width), dtype=jnp.uint8) + alice_obs = alice_obs.at[0:2].set(agent_pos_layers) + + alice_obs = alice_obs.at[2:10].set(agent_direction_layers) + alice_obs = alice_obs.at[10:].set(jnp.stack(env_layers)) + + bob_obs = jnp.zeros((n_channels, height, width), dtype=jnp.uint8) + bob_obs = bob_obs.at[0].set( + agent_pos_layers[1]).at[1].set(agent_pos_layers[0]) + bob_obs = bob_obs.at[2:6].set(agent_direction_layers[4:]).at[6:10].set( + agent_direction_layers[0:4]) + bob_obs = bob_obs.at[10:].set(jnp.stack(env_layers)) + + # NOTE: Changed, was not inline with self.obs_shape: [self.width, self.height, 26] + alice_obs = jnp.transpose(alice_obs, (2, 1, 0)) + bob_obs = jnp.transpose(bob_obs, (2, 1, 0)) + return {"agent_0": alice_obs, "agent_1": bob_obs} + + def step_agents( + self, key: chex.PRNGKey, state: EnvState, action: chex.Array + ) -> Tuple[EnvState, float]: + + # Update agent position (forward action) + is_move_action = jnp.logical_and( + action != Actions.stay, action != Actions.interact) + is_move_action_transposed = jnp.expand_dims( + is_move_action, 0).transpose() # Necessary to broadcast correctly + + fwd_pos = jnp.minimum( + jnp.maximum(state.agent_pos + is_move_action_transposed * DIR_TO_VEC[jnp.minimum(action, 3)] + + ~is_move_action_transposed * state.agent_dir, 0), + jnp.array((self.width - 1, self.height - 1), dtype=jnp.uint32) + ) + + # Can't go past wall or goal + def _wall_or_goal(fwd_position, wall_map, goal_pos): + fwd_wall = wall_map.at[fwd_position[1], fwd_position[0]].get() + def goal_collision(pos, goal): return jnp.logical_and( + pos[0] == goal[0], pos[1] == goal[1]) + fwd_goal = jax.vmap(goal_collision, in_axes=( + None, 0))(fwd_position, goal_pos) + # fwd_goal = jnp.logical_and(fwd_position[0] == goal_pos[0], fwd_position[1] == goal_pos[1]) + fwd_goal = jnp.any(fwd_goal) + return fwd_wall, fwd_goal + + fwd_pos_has_wall, fwd_pos_has_goal = jax.vmap(_wall_or_goal, in_axes=( + 0, None, None))(fwd_pos, state.wall_map, state.goal_pos) + + fwd_pos_blocked = jnp.logical_or( + fwd_pos_has_wall, fwd_pos_has_goal).reshape((self.num_agents, 1)) + + bounced = jnp.logical_or(fwd_pos_blocked, ~is_move_action_transposed) + + # Agents can't overlap + # Hardcoded for 2 agents (call them Alice and Bob) + agent_pos_prev = jnp.array(state.agent_pos) + fwd_pos = (bounced * state.agent_pos + (~bounced) + * fwd_pos).astype(jnp.uint32) + collision = jnp.all(fwd_pos[0] == fwd_pos[1]) + + # No collision = No movement. This matches original Overcooked env. + alice_pos = jnp.where( + collision, + state.agent_pos[0], # collision and Bob bounced + fwd_pos[0], + ) + bob_pos = jnp.where( + collision, + # collision and Alice bounced + state.agent_pos[1], + fwd_pos[1], + ) + + # Prevent swapping places (i.e. passing through each other) + swap_places = jnp.logical_and( + jnp.all(fwd_pos[0] == state.agent_pos[1]), + jnp.all(fwd_pos[1] == state.agent_pos[0]), + ) + alice_pos = jnp.where( + ~collision * swap_places, + state.agent_pos[0], + alice_pos + ) + bob_pos = jnp.where( + ~collision * swap_places, + state.agent_pos[1], + bob_pos + ) + + fwd_pos = fwd_pos.at[0].set(alice_pos) + fwd_pos = fwd_pos.at[1].set(bob_pos) + agent_pos = fwd_pos.astype(jnp.uint32) + + # Update agent direction + agent_dir_idx = ~is_move_action * state.agent_dir_idx + is_move_action * action + agent_dir = DIR_TO_VEC[agent_dir_idx] + + # Handle interacts. Agent 1 first, agent 2 second, no collision handling. + # This matches the original Overcooked + fwd_pos = state.agent_pos + state.agent_dir + maze_map = state.maze_map + is_interact_action = (action == Actions.interact) + + # Compute the effect of interact first, then apply it if needed + candidate_maze_map, alice_inv, alice_reward, alice_shaped_reward = self.process_interact( + maze_map, state, fwd_pos[0], state.agent_inv[0], state.agent_inv[1]) + alice_interact = is_interact_action[0] + bob_interact = is_interact_action[1] + + maze_map = jax.lax.select(alice_interact, + candidate_maze_map, + maze_map) + alice_inv = jax.lax.select(alice_interact, + alice_inv, + state.agent_inv[0]) + alice_reward = jax.lax.select(alice_interact, alice_reward, 0.) + alice_shaped_reward = jax.lax.select( + alice_interact, alice_shaped_reward, 0.) + + candidate_maze_map, bob_inv, bob_reward, bob_shaped_reward = self.process_interact( + maze_map, state, fwd_pos[1], state.agent_inv[1], state.agent_inv[0]) + maze_map = jax.lax.select(bob_interact, + candidate_maze_map, + maze_map) + bob_inv = jax.lax.select(bob_interact, + bob_inv, + state.agent_inv[1]) + bob_reward = jax.lax.select(bob_interact, bob_reward, 0.) + bob_shaped_reward = jax.lax.select(bob_interact, bob_shaped_reward, 0.) + + agent_inv = jnp.array([alice_inv, bob_inv]) + + # Update agent component in maze_map + def _get_agent_updates(agent_dir_idx, agent_pos, agent_pos_prev, agent_idx): + agent = jnp.array([OBJECT_TO_INDEX['agent'], COLOR_TO_INDEX['red'] + + agent_idx*2, agent_dir_idx], dtype=jnp.uint8) + agent_x_prev, agent_y_prev = agent_pos_prev + agent_x, agent_y = agent_pos + return agent_x, agent_y, agent_x_prev, agent_y_prev, agent + + vec_update = jax.vmap(_get_agent_updates, in_axes=(0, 0, 0, 0)) + agent_x, agent_y, agent_x_prev, agent_y_prev, agent_vec = vec_update( + agent_dir_idx, agent_pos, agent_pos_prev, jnp.arange(self.num_agents)) + empty = jnp.array([OBJECT_TO_INDEX['empty'], 0, 0], dtype=jnp.uint8) + + # Compute padding, added automatically by map maker function + # height = self.obs_shape[1] + padding = 4 # (state.maze_map.shape[0] - height) // 2 + + maze_map = maze_map.at[padding + agent_y_prev, + padding + agent_x_prev, :].set(empty) + maze_map = maze_map.at[padding + agent_y, + padding + agent_x, :].set(agent_vec) + + # Update pot cooking status + def _cook_pots(maze_map, pot_pos): + pot_pos_padded = jnp.zeros( + (maze_map.shape[0], maze_map.shape[1]), dtype=jnp.uint8 + ) + pot_pos_padded = pot_pos_padded.at[ + padding:-padding, padding:-padding].set(pot_pos) + is_cooking = jnp.array( + maze_map[:, :, -1] * pot_pos_padded <= POT_FULL_STATUS, dtype=jnp.uint8) * pot_pos_padded + not_done = jnp.array( + maze_map[:, :, -1] * pot_pos_padded > POT_READY_STATUS, dtype=jnp.uint8) * pot_pos_padded + pot_status_is_cooking_not_done = is_cooking * \ + not_done * (maze_map[:, :, -1] - 1) * pot_pos_padded + pot_status_is_not_cooking = jnp.logical_not( + is_cooking) * (maze_map[:, :, -1]) * pot_pos_padded # defaults to zero if done pot_status + pot_status = pot_status_is_cooking_not_done + pot_status_is_not_cooking + + pot_status_map = pot_pos_padded * pot_status + \ + jnp.logical_not(pot_pos_padded) * maze_map[:, :, -1] + pot_status_map = jnp.concatenate( + (jnp.zeros((*pot_status_map.shape, 2), dtype=jnp.uint8), pot_status_map[:, :, jnp.newaxis]), axis=-1) + + pot_pos_3 = jnp.concatenate( + (jnp.zeros((pot_status_map.shape[0], pot_status_map.shape[1], 2), dtype=jnp.uint8), pot_pos_padded[:, :, jnp.newaxis]), axis=-1) + + maze_map = maze_map * (1-pot_pos_3) + pot_status_map * pot_pos_3 + + return maze_map # pot.at[-1].set(pot_status) + + maze_map = _cook_pots(maze_map, state.pot_pos) + + reward = alice_reward + bob_reward + # shaped_reward = alice_shaped_reward + bob_shaped_reward + + return ( + state.replace( + agent_pos=agent_pos, + agent_dir_idx=agent_dir_idx, + agent_dir=agent_dir, + agent_inv=agent_inv, + maze_map=maze_map, + terminal=False), + reward, + alice_shaped_reward, + bob_shaped_reward, + ) + + def process_interact( + self, + maze_map: chex.Array, + state: EnvState, + fwd_pos: chex.Array, + inventory: chex.Array, + other_inventory: chex.Array): + """Assume agent took interact actions. Result depends on what agent is facing and what it is holding.""" + + wall_map = state.wall_map + height = self.height # self.obs_shape[1] + # padding = (maze_map.shape[0] - height) // 2 + padding = 4 + + # Get object in front of agent (on the "table") + maze_object_on_table = maze_map.at[padding + + fwd_pos[1], padding + fwd_pos[0]].get() + object_on_table = maze_object_on_table[0] # Simple index + + # Booleans depending on what the object is + object_is_pile = jnp.logical_or( + object_on_table == OBJECT_TO_INDEX["plate_pile"], object_on_table == OBJECT_TO_INDEX["onion_pile"]) + object_is_pot = jnp.array(object_on_table == OBJECT_TO_INDEX["pot"]) + object_is_goal = jnp.array(object_on_table == OBJECT_TO_INDEX["goal"]) + object_is_agent = jnp.array( + object_on_table == OBJECT_TO_INDEX["agent"]) + object_is_pickable = jnp.logical_or( + jnp.logical_or( + object_on_table == OBJECT_TO_INDEX["plate"], object_on_table == OBJECT_TO_INDEX["onion"]), + object_on_table == OBJECT_TO_INDEX["dish"] + ) + # Whether the object in front is counter space that the agent can drop on. + is_table = jnp.logical_and( + wall_map.at[fwd_pos[1], fwd_pos[0]].get(), ~object_is_pot) + + table_is_empty = jnp.logical_or( + object_on_table == OBJECT_TO_INDEX["wall"], object_on_table == OBJECT_TO_INDEX["empty"]) + + # Pot status (used if the object is a pot) + pot_status = maze_object_on_table[-1] + + # Get inventory object, and related booleans + inv_is_empty = jnp.array(inventory == OBJECT_TO_INDEX["empty"]) + object_in_inv = inventory + holding_onion = jnp.array(object_in_inv == OBJECT_TO_INDEX["onion"]) + holding_plate = jnp.array(object_in_inv == OBJECT_TO_INDEX["plate"]) + holding_dish = jnp.array(object_in_inv == OBJECT_TO_INDEX["dish"]) + + # Interactions with pot. 3 cases: add onion if missing, collect soup if ready, do nothing otherwise + case_1 = (pot_status > POT_FULL_STATUS) * holding_onion * object_is_pot + case_2 = (pot_status == POT_READY_STATUS) * \ + holding_plate * object_is_pot + case_3 = (pot_status > POT_READY_STATUS) * \ + (pot_status <= POT_FULL_STATUS) * object_is_pot + else_case = ~case_1 * ~case_2 * ~case_3 + + # Update pot status and object in inventory + new_pot_status = \ + case_1 * (pot_status - 1) \ + + case_2 * POT_EMPTY_STATUS \ + + case_3 * pot_status \ + + else_case * pot_status + new_object_in_inv = \ + case_1 * OBJECT_TO_INDEX["empty"] \ + + case_2 * OBJECT_TO_INDEX["dish"] \ + + case_3 * object_in_inv \ + + else_case * object_in_inv + + # Interactions with onion/plate piles and objects on counter + # Pickup if: table, not empty, room in inv & object is not something unpickable (e.g. pot or goal) + successful_pickup = is_table * ~table_is_empty * inv_is_empty * \ + jnp.logical_or(object_is_pile, object_is_pickable) + successful_drop = is_table * table_is_empty * ~inv_is_empty + successful_delivery = is_table * object_is_goal * holding_dish + no_effect = jnp.logical_and(jnp.logical_and( + ~successful_pickup, ~successful_drop), ~successful_delivery) + + # Update object on table + new_object_on_table = \ + no_effect * object_on_table \ + + successful_delivery * object_on_table \ + + successful_pickup * object_is_pile * object_on_table \ + + successful_pickup * object_is_pickable * OBJECT_TO_INDEX["wall"] \ + + successful_drop * object_in_inv + + # Update object in inventory + new_object_in_inv = \ + no_effect * new_object_in_inv \ + + successful_delivery * OBJECT_TO_INDEX["empty"] \ + + successful_pickup * object_is_pickable * object_on_table \ + + successful_pickup * (object_on_table == OBJECT_TO_INDEX["plate_pile"]) * OBJECT_TO_INDEX["plate"] \ + + successful_pickup * (object_on_table == OBJECT_TO_INDEX["onion_pile"]) * OBJECT_TO_INDEX["onion"] \ + + successful_drop * OBJECT_TO_INDEX["empty"] + + # Apply inventory update + inventory = new_object_in_inv + + # Apply changes to maze + new_maze_object_on_table = \ + object_is_pot * OBJECT_INDEX_TO_VEC[new_object_on_table].at[-1].set(new_pot_status) \ + + ~object_is_pot * ~object_is_agent * OBJECT_INDEX_TO_VEC[new_object_on_table] \ + + object_is_agent * maze_object_on_table + + maze_map = maze_map.at[padding + fwd_pos[1], + padding + fwd_pos[0], :].set(new_maze_object_on_table) + + # Reward of 20 for a soup delivery + reward = jnp.array(successful_delivery, dtype=float)*DELIVERY_REWARD + + no_plate_on_counter = ( + (maze_map[padding:-padding, padding:-padding, 0] * wall_map) == OBJECT_TO_INDEX["plate"]).sum() == 0 + num_pots = state.pot_pos.sum() + # (maze_map[padding:-padding, padding:-padding, -1].at[state.pot_pos].get() <= POT_FULL_STATUS).sum() + num_pots_cooking = ( + (maze_map[padding:-padding, padding:-padding, -1] <= POT_FULL_STATUS) * state.pot_pos).sum() + # (maze_map[padding:-padding, padding:-padding, -1].at[state.pot_pos].get() > POT_FULL_STATUS).sum() + num_pots_not_started = ( + (maze_map[padding:-padding, padding:-padding, -1] > POT_FULL_STATUS) * state.pot_pos).sum() + num_pots_ready = num_pots - num_pots_cooking - num_pots_not_started + pot_left_over_for_plate = (num_pots_cooking + num_pots_ready - + 1 * (other_inventory == OBJECT_TO_INDEX["dish"])) > 0 + # As in orignal work: adding onion 3, getting a bowl while cooking 5, pickung up a soup 5 + shaped_reward_c1 = (new_object_in_inv == OBJECT_TO_INDEX["empty"]) * ( + object_in_inv == OBJECT_TO_INDEX["onion"]) * case_1 * 3.0 + shaped_reward_c2 = (new_object_in_inv == OBJECT_TO_INDEX["plate"]) * (object_on_table == OBJECT_TO_INDEX["plate_pile"]) * \ + successful_pickup * no_plate_on_counter * pot_left_over_for_plate * 5.0 + shaped_reward_c3 = (new_object_in_inv == OBJECT_TO_INDEX["dish"]) * ( + object_in_inv == OBJECT_TO_INDEX["plate"]) * case_2 * 5.0 + + # jax.debug.print("no_plate {a}: {s}", a=no_plate_on_counter, s=shaped_reward_c2) + shaped_reward = shaped_reward_c1 + shaped_reward_c2 + shaped_reward_c3 + return maze_map, inventory, reward, shaped_reward + + def is_terminal(self, state: EnvState) -> bool: + """Check whether state is terminal.""" + done_steps = state.time >= self.max_steps + return done_steps | state.terminal + + def get_eval_solved_rate_fn(self): + def _fn(ep_stats): + return ep_stats['return'] > 20 # More than one soup delivered + + return _fn + + @property + def name(self) -> str: + """Environment name.""" + return "Overcooked" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return len(self.action_set) + + def action_space(self, agent_id="") -> spaces.Discrete: + """Action space of the environment. Agent_id not used since action_space is uniform for all agents""" + return spaces.Discrete( + len(self.action_set), + dtype=jnp.uint8 + ) + + def observation_space(self) -> spaces.Box: + """Observation space of the environment.""" + return spaces.Box(0, 255, self.obs_shape) + + def max_episode_steps(self) -> int: + return self.params.max_episode_steps + + def set_env_instance( + self, + encoding: EnvInstance): + """ + Instance is encoded as a PyTree containing the following fields: + agent_pos, agent_dir, goal_pos, wall_map + """ + params = self.params + agent_pos = encoding.agent_pos + agent_dir_idx = encoding.agent_dir_idx + h, w = encoding.wall_map.shape + agent_dir = DIR_TO_VEC.at[agent_dir_idx].get() + goal_pos = encoding.goal_pos + wall_map = encoding.wall_map + agent_inv = encoding.agent_inv + pot_pos = encoding.pot_pos + + onion_pile_pos = encoding.onion_pile_pos + plate_pile_pos = encoding.plate_pile_pos + + onion_pos = jnp.zeros((h, w), dtype=jnp.uint8) + plate_pos = jnp.zeros((h, w), dtype=jnp.uint8) + dish_pos = jnp.zeros((h, w), dtype=jnp.uint8) + + pot_status = jnp.ones( + (encoding.wall_map.reshape(-1).shape), dtype=jnp.uint8) * 23 + + maze_map = make_overcooked_map( + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + plate_pile_pos, + onion_pile_pos, + pot_pos, + pot_status, + onion_pos, + plate_pos, + dish_pos, + pad_obs=True, + num_agents=2, + agent_view_size=5) + + state = EnvState( + agent_pos=agent_pos, + agent_dir=agent_dir, + agent_dir_idx=agent_dir_idx, + goal_pos=goal_pos, + wall_map=wall_map, + maze_map=maze_map, + bowl_pile_pos=plate_pile_pos, + onion_pile_pos=onion_pile_pos, + agent_inv=agent_inv, + pot_pos=pot_pos, + time=0, + terminal=False + ) + + return self.get_obs(state), state + + def get_env_metrics(self, state: EnvState) -> dict: + n_walls = state.wall_map.sum() + return dict( + n_walls=n_walls, + ) + + def state_space(self) -> spaces.Dict: + """EnvState space of the environment.""" + h = self.height + w = self.width + agent_view_size = self.agent_view_size + return spaces.Dict({ + "agent_pos": spaces.Box(0, max(w, h), (2,), dtype=jnp.uint32), + "agent_dir": spaces.Discrete(4), + "goal_pos": spaces.Box(0, max(w, h), (2,), dtype=jnp.uint32), + "maze_map": spaces.Box(0, 255, (w + agent_view_size, h + agent_view_size, 3), dtype=jnp.uint32), + "time": spaces.Discrete(self.max_steps), + "terminal": spaces.Discrete(2), + }) + + def max_steps(self) -> int: + return self.max_steps + + def get_monitored_metrics(self): + return ('reward', 'shaped_reward', 'shaped_reward_scaled_by_shaped_reward_coeff', 'reward_p_shaped_reward_scaled') + + @property + def default_params(self) -> EnvParams: + # Default environment parameters + return EnvParams() + + +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register(env_id='Overcooked', entry_point=module_path + ':Overcooked') + +if __name__ == '__main__': + from minimax.envs.wrappers import MonitorReturnWrapper + + render = False + n_envs = 1 + + kwargs = dict( + # max_episode_steps=400, + height=6, + width=9, + n_walls=15, + agent_view_size=5, + fix_to_single_layout="coord_ring_6_9" + ) + env = MonitorReturnWrapper(Overcooked(**kwargs)) + params = env.params + extra = env.reset_extra() + + jit_reset_env = env.reset + jit_step_env = env.step + + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + obs, state, extra = jit_reset_env(subkey) + + all_sps = [] + + import time + for ac in [0, 0, 5, 0, 0]: # [1, 1, 3, 1, 5]: + key, subkey = jax.random.split(key) + # vrngs = jax.random.split(subkey) + start = time.time() + jax.debug.print('obs:\n{a}', a=(obs['agent_0'][:, :, 0] + * 1 + obs['agent_0'][:, :, 1]*2+obs['agent_0'][:, :, 11]*3).T) + obs, state, reward, done, info, extra = jit_step_env( + subkey, + state, + action={ + 'agent_0': ac, + 'agent_1': ac + }, + extra=extra + ) + jax.debug.print("reward r {r} {ir} {isr}", r=reward, + ir=info["sparse_reward"], isr=info["shaped_reward"]) + + state = state.replace(agent_inv=jnp.array( + [OBJECT_TO_INDEX['onion'], OBJECT_TO_INDEX['onion']])) + + obs['agent_0'].block_until_ready() + end = time.time() + # print(f"sps: {1/(end-start) * n_envs}") + # print('return:', info['return']) + all_sps.append(1/(end-start) * n_envs) + + print('mean sps:', np.mean(all_sps)) + print('std sps:', np.std(all_sps)) diff --git a/src/minimax/envs/overcooked_proc/overcooked_comparators.py b/src/minimax/envs/overcooked_proc/overcooked_comparators.py new file mode 100644 index 0000000..c2444bf --- /dev/null +++ b/src/minimax/envs/overcooked_proc/overcooked_comparators.py @@ -0,0 +1,40 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import jax +import jax.numpy as jnp + +from minimax.envs.registration import register_comparator + + +@jax.jit +def is_equal_map(a, b): + agent_pos_eq = jnp.equal(a.agent_pos, b.agent_pos).all() + goal_pos_eq = jnp.equal(a.goal_pos, b.goal_pos).all() + wall_map_eq = jnp.equal(a.wall_map, b.wall_map).all() + pot_pos_eq = jnp.equal(a.pot_pos, b.pot_pos).all() + onion_pos_eq = jnp.equal(a.onion_pile_pos, b.onion_pile_pos).all() + bowl_pos_eq = jnp.equal(a.bowl_pile_pos, b.bowl_pile_pos).all() + + _eq = jnp.logical_and(agent_pos_eq, goal_pos_eq) + _eq = jnp.logical_and(_eq, pot_pos_eq) + _eq = jnp.logical_and(_eq, onion_pos_eq) + _eq = jnp.logical_and(_eq, bowl_pos_eq) + _eq = jnp.logical_and(_eq, wall_map_eq) + + return _eq + + +# Register the mutators +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register_comparator(env_id='Overcooked', comparator_id=None, + entry_point=module_path + ':is_equal_map') diff --git a/src/minimax/envs/overcooked_proc/overcooked_mutators.py b/src/minimax/envs/overcooked_proc/overcooked_mutators.py new file mode 100644 index 0000000..9612304 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/overcooked_mutators.py @@ -0,0 +1,253 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from enum import IntEnum + +import numpy as np +import jax +import jax.numpy as jnp + +from .common import make_overcooked_map +from minimax.envs.registration import register_mutator + + +class Mutations(IntEnum): + # Turn left, turn right, move forward + NO_OP = 0 + FLIP_WALL = 1 + MOVE_GOAL = 2 + MOVE_POT = 3 + MOVE_ONION_PILE = 4 + MOVE_PLATE_PILE = 5 + + +def add_1_or_2_items(key, all_pos, legal_pick_mask): + obj_mask = jnp.zeros_like(all_pos, dtype=jnp.uint8) + key, subkey = jax.random.split(key) + item_idx_1 = jax.random.choice(subkey, all_pos, shape=( + 1,), p=(legal_pick_mask.astype(jnp.bool_)).astype(jnp.uint8)) + + key, subkey = jax.random.split(key) + and_2 = jax.random.bernoulli(subkey, 0.5) + + key, subkey = jax.random.split(key) + item_idx_2 = jax.random.choice(subkey, all_pos, shape=( + 1,), p=(legal_pick_mask.astype(jnp.bool_)).astype(jnp.uint8)) + + obj_mask = obj_mask.at[item_idx_1].set(1) + + update_2 = jnp.logical_or( + obj_mask.at[item_idx_2].get(), and_2.astype(jnp.uint8)) + obj_mask = obj_mask.at[item_idx_2].set(update_2) + return obj_mask + + +def flip_wall(rng, state): + wall_map = state.wall_map + h, w = wall_map.shape + wall_mask = jnp.ones((h*w,), dtype=jnp.bool_) + + goal_idx = state.goal_pos.flatten() + agent_idx = state.agent_pos.flatten() + pot_pos_idx = state.pot_pos.flatten() + onion_pile_pos_idx = state.onion_pile_pos.flatten() + plate_pile_pos_idx = state.bowl_pile_pos.flatten() + + # Do not flip wall below an object or agent + wall_mask = wall_mask.at[goal_idx].set(False) + wall_mask = wall_mask.at[agent_idx].set(False) + wall_mask = wall_mask.at[pot_pos_idx].set(False) + wall_mask = wall_mask.at[onion_pile_pos_idx].set(False) + wall_mask = wall_mask.at[plate_pile_pos_idx].set(False) + + # Never allowed to flip a edge in overcooked + wall_mask = wall_mask.reshape(h, w) + wall_mask = wall_mask.at[:, 0].set(False) + wall_mask = wall_mask.at[:, -1].set(False) + wall_mask = wall_mask.at[0, :].set(False) + wall_mask = wall_mask.at[-1, :].set(False) + wall_mask = wall_mask.flatten() + + flip_idx = jax.random.choice(rng, np.arange(h*w), shape=(), p=wall_mask) + + wall_map = wall_map.flatten() + flip_val = ~wall_map.at[flip_idx].get() + + wall_map = wall_map.at[flip_idx].set(flip_val) + next_wall_map = wall_map.reshape(state.wall_map.shape) + return state.replace(wall_map=next_wall_map) + + +def move_goal(rng, state): + wall_map = state.wall_map + h, w = wall_map.shape + wall_mask = wall_map.flatten() + + onion_pile_pos_idx = state.onion_pile_pos.flatten() + bowl_pile_pos_idx = state.bowl_pile_pos.flatten() + goal_idx = state.goal_pos.flatten() + pot_idx = state.pot_pos.flatten() + + # No previous position and other objects + wall_mask = wall_mask.at[goal_idx].set(False) + wall_mask = wall_mask.at[pot_idx].set(False) + wall_mask = wall_mask.at[bowl_pile_pos_idx].set(False) + wall_mask = wall_mask.at[onion_pile_pos_idx].set(False) + + # Move around the wall + all_pos = jnp.zeros((h*w,), dtype=jnp.uint8) + next_goal_pos = add_1_or_2_items(rng, all_pos, wall_mask) + return state.replace(goal_pos=next_goal_pos.reshape(state.goal_pos.shape)) + + +def move_pot(rng, state): + wall_map = state.wall_map + h, w = wall_map.shape + wall_mask = wall_map.flatten() + + onion_pile_pos_idx = state.onion_pile_pos.flatten() + bowl_pile_pos_idx = state.bowl_pile_pos.flatten() + goal_idx = state.goal_pos.flatten() + pot_idx = state.pot_pos.flatten() + + # No previous position and other objects + wall_mask = wall_mask.at[goal_idx].set(False) + wall_mask = wall_mask.at[pot_idx].set(False) + wall_mask = wall_mask.at[bowl_pile_pos_idx].set(False) + wall_mask = wall_mask.at[onion_pile_pos_idx].set(False) + + # Move around the wall + all_pos = jnp.zeros((h*w,), dtype=jnp.uint8) + next_pot_pos = add_1_or_2_items(rng, all_pos, wall_mask) + return state.replace(pot_pos=next_pot_pos.reshape(state.pot_pos.shape)) + + +def move_onion_pile(rng, state): + wall_map = state.wall_map + h, w = wall_map.shape + wall_mask = wall_map.flatten() + + onion_pile_pos_idx = state.onion_pile_pos.flatten() + bowl_pile_pos_idx = state.bowl_pile_pos.flatten() + goal_idx = state.goal_pos.flatten() + pot_idx = state.pot_pos.flatten() + + # No previous position and other objects + wall_mask = wall_mask.at[goal_idx].set(False) + wall_mask = wall_mask.at[pot_idx].set(False) + wall_mask = wall_mask.at[bowl_pile_pos_idx].set(False) + wall_mask = wall_mask.at[onion_pile_pos_idx].set(False) + + # Move around the wall + all_pos = jnp.zeros((h*w,), dtype=jnp.uint8) + next_onion_pile_pos = add_1_or_2_items(rng, all_pos, wall_mask) + return state.replace(onion_pile_pos=next_onion_pile_pos.reshape(state.onion_pile_pos.shape)) + + +def move_bowl_pile(rng, state): + wall_map = state.wall_map + h, w = wall_map.shape + wall_mask = wall_map.flatten() + + onion_pile_pos_idx = state.onion_pile_pos.flatten() + bowl_pile_pos_idx = state.bowl_pile_pos.flatten() + goal_idx = state.goal_pos.flatten() + pot_idx = state.pot_pos.flatten() + + # No previous position and other objects + wall_mask = wall_mask.at[goal_idx].set(False) + wall_mask = wall_mask.at[pot_idx].set(False) + wall_mask = wall_mask.at[bowl_pile_pos_idx].set(False) + wall_mask = wall_mask.at[onion_pile_pos_idx].set(False) + + # Move around the wall + all_pos = jnp.zeros((h*w,), dtype=jnp.uint8) + next_plate_pile_pos = add_1_or_2_items(rng, all_pos, wall_mask) + return state.replace(bowl_pile_pos=next_plate_pile_pos.reshape(state.bowl_pile_pos.shape)) + + +@partial(jax.jit, static_argnums=(1, 3)) +def move_goal_flip_walls(rng, params, state, n=1): + if n == 0: + return state + + def _mutate(carry, step): + state = carry + rng, mutation = step + + rng, arng, brng, crng, drng, erng = jax.random.split(rng, 6) + + is_flip_wall = jnp.equal(mutation, Mutations.FLIP_WALL.value) + mutated_state = flip_wall(arng, state) + next_state = jax.tree_map(lambda x, y: jax.lax.select( + is_flip_wall, x, y), mutated_state, state) + + is_move_goal = jnp.equal(mutation, Mutations.MOVE_GOAL.value) + mutated_state = move_goal(brng, state) + next_state = jax.tree_map(lambda x, y: jax.lax.select( + is_move_goal, x, y), mutated_state, next_state) + + is_move_pot = jnp.equal(mutation, Mutations.MOVE_POT.value) + mutated_state = move_pot(crng, state) + next_state = jax.tree_map(lambda x, y: jax.lax.select( + is_move_pot, x, y), mutated_state, next_state) + + is_move_onion_pile = jnp.equal( + mutation, Mutations.MOVE_ONION_PILE.value) + mutated_state = move_onion_pile(drng, state) + next_state = jax.tree_map(lambda x, y: jax.lax.select( + is_move_onion_pile, x, y), mutated_state, next_state) + + is_move_plate_pile = jnp.equal( + mutation, Mutations.MOVE_PLATE_PILE.value) + mutated_state = move_bowl_pile(erng, state) + next_state = jax.tree_map(lambda x, y: jax.lax.select( + is_move_plate_pile, x, y), mutated_state, next_state) + + return next_state, None + + rng, nrng, *mrngs = jax.random.split(rng, n+2) + mutations = jax.random.choice(nrng, np.arange(len(Mutations)), (n,)) + + state, _ = jax.lax.scan(_mutate, state, (jnp.array(mrngs), mutations)) + + onion_pos = jnp.zeros(state.wall_map.shape, dtype=jnp.uint8) + plate_pos = jnp.zeros(state.wall_map.shape, dtype=jnp.uint8) + dish_pos = jnp.zeros(state.wall_map.shape, dtype=jnp.uint8) + + pot_status = jnp.ones((state.wall_map.reshape(-1).shape), dtype=jnp.uint8) * 23 + + next_maze_map = make_overcooked_map( + state.wall_map, + state.goal_pos, + state.agent_pos, + state.agent_dir_idx, + state.bowl_pile_pos, + state.onion_pile_pos, + state.pot_pos, + pot_status, + onion_pos, + plate_pos, + dish_pos, + pad_obs=True, + num_agents=2, + agent_view_size=5 + ) + + return state.replace(maze_map=next_maze_map) + + +# Register the mutators +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register_mutator(env_id='Overcooked', mutator_id=None, + entry_point=module_path + ':move_goal_flip_walls') diff --git a/src/minimax/envs/overcooked_proc/overcooked_ood.py b/src/minimax/envs/overcooked_proc/overcooked_ood.py new file mode 100644 index 0000000..e600899 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/overcooked_ood.py @@ -0,0 +1,405 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Tuple + +import jax +import jax.numpy as jnp +import chex + +from flax.core.frozen_dict import FrozenDict + +from minimax.envs.registration import register +from minimax.envs.overcooked_proc.layouts import layout_grid_to_onehot_dict +from .common import ( + OBJECT_TO_INDEX, + make_overcooked_map, +) +from .overcooked import ( + Overcooked, + EnvParams, + EnvState, + _obtain_from_layout, + +) + +cramped_room = { + "height": 4, + "width": 5, + "wall_idx": jnp.array([0, 1, 2, 3, 4, + 5, 9, + 10, 14, + 15, 16, 17, 18, 19]), + "agent_idx": jnp.array([6, 8]), + "goal_idx": jnp.array([18]), + "plate_pile_idx": jnp.array([16]), + "onion_pile_idx": jnp.array([5, 9]), + "pot_idx": jnp.array([2]) +} +asymm_advantages = { + "height": 5, + "width": 9, + "wall_idx": jnp.array([0, 1, 2, 3, 4, 5, 6, 7, 8, + 9, 11, 12, 13, 14, 15, 17, + 18, 22, 26, + 27, 31, 35, + 36, 37, 38, 39, 40, 41, 42, 43, 44]), + "agent_idx": jnp.array([29, 32]), + "goal_idx": jnp.array([12, 17]), + "plate_pile_idx": jnp.array([39, 41]), + "onion_pile_idx": jnp.array([9, 14]), + "pot_idx": jnp.array([22, 31]) +} +coord_ring = { + "height": 5, + "width": 5, + "wall_idx": jnp.array([0, 1, 2, 3, 4, + 5, 9, + 10, 12, 14, + 15, 19, + 20, 21, 22, 23, 24]), + "agent_idx": jnp.array([7, 11]), + "goal_idx": jnp.array([22]), + "plate_pile_idx": jnp.array([10]), + "onion_pile_idx": jnp.array([15, 21]), + "pot_idx": jnp.array([3, 9]) +} +forced_coord = { + "height": 5, + "width": 5, + "wall_idx": jnp.array([0, 1, 2, 3, 4, + 5, 7, 9, + 10, 12, 14, + 15, 17, 19, + 20, 21, 22, 23, 24]), + "agent_idx": jnp.array([11, 8]), + "goal_idx": jnp.array([23]), + "onion_pile_idx": jnp.array([5, 10]), + "plate_pile_idx": jnp.array([15]), + "pot_idx": jnp.array([3, 9]) +} + +# Example of layout provided as a grid +counter_circuit_grid = """ +WWWPPWWW +W A W +B WWWW X +W AW +WWWOOWWW +""" + +asymm_advantages_6_9 = """ +WWWWWWWWW +O WXWOW X +W P A W +WA P W +WWWBWBWWW +WWWWWWWWW +""" + +counter_circuit_6_9 = """ +WWWPPWWWW +W A WW +B WWWW XW +W AWW +WWWOOWWWW +WWWWWWWWW +""" + +forced_coord_6_9 = """ +WWWPWWWWW +OAWAPWWWW +O W WWWWW +B W WWWWW +WWWXWWWWW +WWWWWWWWW +""" + +cramped_room_6_9 = """ +WWPWWWWWW +OAA OWWWW +W WWWWW +WBWXWWWWW +WWWWWWWWW +WWWWWWWWW +""" + +coord_ring_6_9 = """ +WWWPWWWWW +WA APWWWW +B W WWWWW +O WWWWW +WOXWWWWWW +WWWWWWWWW +""" + +forced_coord_5_5 = """ +WWWPW +OAWAP +O W W +B W W +WWWXW +""" + +cramped_room_5_5 = """ +WWPWW +OAA O +W W +WBWXW +WWWWW +""" + +coord_ring_5_5 = """ +WWWPW +WA AP +B W W +O W +WOXWW +""" + + +# ======== Singleton mazes ======== +class OvercookedSingleton(Overcooked): + def __init__( + self, + grid, + agent_view_size=5, + replace_wall_pos=False, + max_steps=400, + normalize_obs=False, + sample_n_walls=False, + singleton_seed=-1 + ): + height = grid["height"] + width = grid["width"] + super().__init__( + height=height, + width=width, + agent_view_size=agent_view_size, + replace_wall_pos=replace_wall_pos and not sample_n_walls, + max_steps=max_steps, + normalize_obs=normalize_obs, + sample_n_walls=sample_n_walls, + singleton_seed=singleton_seed, + ) + + self.params = EnvParams( + height=height, + width=width, + agent_view_size=agent_view_size, + normalize_obs=normalize_obs, + max_steps=max_steps, + singleton_seed=singleton_seed, + ) + + h = self.height + w = self.width + + # NOTE: that since the layout is fixed, the random_reset is set to False + # and this is why jax.random.PRNGKey(0) is used too (not needed if no random_reset). + wall_map, goal_pos, agent_pos, agent_dir, agent_dir_idx, plate_pile_pos, onion_pile_pos, pot_pos, pot_status\ + = _obtain_from_layout(jax.random.PRNGKey(0), grid, h, w, random_reset=False, num_agents=2) + + onion_pos = jnp.zeros((h, w), dtype=jnp.uint8) + plate_pos = jnp.zeros((h, w), dtype=jnp.uint8) + dish_pos = jnp.zeros((h, w), dtype=jnp.uint8) + + agent_inv = jnp.array( + [OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['empty']]) + + self.overcooked_map = make_overcooked_map( + wall_map, + goal_pos, + agent_pos, + agent_dir_idx, + plate_pile_pos, + onion_pile_pos, + pot_pos, + pot_status, + onion_pos, + plate_pos, + dish_pos, + pad_obs=True, + num_agents=self.num_agents, + agent_view_size=self.agent_view_size) + + self.agent_pos = agent_pos + self.agent_dir = agent_dir + self.agent_dir_idx = agent_dir_idx + self.agent_inv = agent_inv + self.goal_pos = goal_pos + self.pot_pos = pot_pos + self.bowl_pile_pos = plate_pile_pos + self.onion_pile_pos = onion_pile_pos + self.wall_map = wall_map + + @property + def default_params(self) -> EnvParams: + # Default environment parameters + return EnvParams() + + def reset_env( + self, + key: chex.PRNGKey, + ) -> Tuple[chex.Array, EnvState]: + + state = EnvState( + agent_pos=self.agent_pos, + agent_dir=self.agent_dir, + agent_dir_idx=self.agent_dir_idx, + agent_inv=self.agent_inv, + goal_pos=self.goal_pos, + pot_pos=self.pot_pos, + wall_map=self.wall_map.astype(jnp.bool_), + maze_map=self.overcooked_map, + bowl_pile_pos=self.bowl_pile_pos, + onion_pile_pos=self.onion_pile_pos, + time=0, + terminal=False, + ) + + return self.get_obs(state), state + + +# ======== Specific mazes ======== +class CoordRing6_9(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "coord_ring_6_9" + + grid = layout_grid_to_onehot_dict(coord_ring_6_9) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +class ForcedCoord6_9(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "forced_coord_6_9" + + grid = layout_grid_to_onehot_dict(forced_coord_6_9) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +class CounterCircuit6_9(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "counter_circuit_6_9" + + grid = layout_grid_to_onehot_dict(counter_circuit_6_9) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +class AsymmAdvantages6_9(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "asymm_advantages_6_9" + + grid = layout_grid_to_onehot_dict(asymm_advantages_6_9) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +class CrampedRoom6_9(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "cramped_room_6_9" + + grid = layout_grid_to_onehot_dict(cramped_room_6_9) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +class CoordRing5_5(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "coord_ring_5_5" + + grid = layout_grid_to_onehot_dict(coord_ring_5_5) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +class ForcedCoord5_5(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "forced_coord_5_5" + + grid = layout_grid_to_onehot_dict(forced_coord_5_5) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +class CrampedRoom5_5(OvercookedSingleton): + def __init__( + self, + normalize_obs=False): + self.layout_name = "cramped_room_5_5" + + grid = layout_grid_to_onehot_dict(cramped_room_5_5) + + super().__init__( + grid=grid, + normalize_obs=normalize_obs, + ) + + +# ======== Registration ======== +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +# register(env_id='Overcooked', entry_point=module_path + ':') +register(env_id='Overcooked-CoordRing6_9', + entry_point=module_path + ':CoordRing6_9') +register(env_id='Overcooked-ForcedCoord6_9', + entry_point=module_path + ':ForcedCoord6_9') +register(env_id='Overcooked-CounterCircuit6_9', + entry_point=module_path + ':CounterCircuit6_9') +register(env_id='Overcooked-AsymmAdvantages6_9', + entry_point=module_path + ':AsymmAdvantages6_9') +register(env_id='Overcooked-CrampedRoom6_9', + entry_point=module_path + ':CrampedRoom6_9') + +register(env_id='Overcooked-CoordRing5_5', + entry_point=module_path + ':CoordRing5_5') +register(env_id='Overcooked-ForcedCoord5_5', + entry_point=module_path + ':ForcedCoord5_5') +register(env_id='Overcooked-CrampedRoom5_5', + entry_point=module_path + ':CrampedRoom5_5') diff --git a/src/minimax/envs/overcooked_proc/overcooked_ued.py b/src/minimax/envs/overcooked_proc/overcooked_ued.py new file mode 100644 index 0000000..f33da24 --- /dev/null +++ b/src/minimax/envs/overcooked_proc/overcooked_ued.py @@ -0,0 +1,541 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections import OrderedDict +from enum import IntEnum + +import jax +import jax.numpy as jnp +from jax import lax +from typing import Dict, Tuple +import chex +from flax import struct + +from minimax.envs.overcooked_proc.overcooked import DIR_TO_VEC, EnvState + +from .common import OBJECT_TO_INDEX, EnvInstance, make_overcooked_map +from minimax.envs import environment, spaces +from minimax.envs.registration import register_ued + + +class SequentialActions(IntEnum): + skip = 0 + wall = 1 + goal = 2 + agent = 3 + onion = 4 + soup = 5 + bowls = 6 + + +@struct.dataclass +class UEDEnvState: + encoding: chex.Array + time: int + terminal: bool + + +@struct.dataclass +class EnvParams: + height: int = 6 + width: int = 9 + n_walls: int = 25 + noise_dim: int = 50 + agent_view_size: int = 5 + replace_wall_pos: bool = False + fixed_n_wall_steps: bool = False + first_wall_pos_sets_budget: bool = False + use_seq_actions: bool = False + normalize_obs: bool = False + sample_n_walls: bool = False # Sample n_walls uniformly in [0, n_walls] + max_steps: int = 400 + singleton_seed: int = -1 + max_episode_steps: int = 400 + + +class UEDOvercooked(environment.Environment): + def __init__( + self, + height=6, + width=9, + n_walls=25, + noise_dim=16, + replace_wall_pos=False, + fixed_n_wall_steps=False, + first_wall_pos_sets_budget=False, + use_seq_actions=False, + normalize_obs=False, + ): + """ + Using the original action space requires ensuring proper handling + of a sequence with trailing dones, e.g. dones: 0 0 0 0 1 1 1 1 1 ... 1. + Advantages and value losses should only be computed where ~dones[0]. + """ + assert not (first_wall_pos_sets_budget and fixed_n_wall_steps), \ + 'Setting first_wall_pos_sets_budget=True requires fixed_n_wall_steps=False.' + + super().__init__() + + self.n_tiles = height*width + # go straight, turn left, turn right, take action + self.action_set = jnp.array(jnp.arange(self.n_tiles)) + + self.agents = ["agent_0", "agent_1"] + + self.params = EnvParams( + height=height, + width=width, + n_walls=n_walls, + noise_dim=noise_dim, + replace_wall_pos=replace_wall_pos, + fixed_n_wall_steps=fixed_n_wall_steps, + first_wall_pos_sets_budget=first_wall_pos_sets_budget, + use_seq_actions=False, + normalize_obs=normalize_obs, + ) + + @staticmethod + def align_kwargs(kwargs, other_kwargs): + kwargs.update(dict( + height=other_kwargs['height'], + width=other_kwargs['width'], + )) + + return kwargs + + def _add_noise_to_obs(self, rng, obs): + if self.params.noise_dim > 0: + noise = jax.random.uniform(rng, (self.params.noise_dim,)) + obs.update(dict(noise=noise)) + + return obs + + def reset_env( + self, + key: chex.PRNGKey): + """ + Prepares the environment state for a new design + from a blank slate. + """ + params = self.params + noise_rng, dir_rng = jax.random.split(key) + encoding = jnp.zeros((self._get_encoding_dim(),), dtype=jnp.uint8) + + state = UEDEnvState( + encoding=encoding, + time=0, + terminal=False, + ) + + obs = self._add_noise_to_obs( + noise_rng, + self.get_obs(state) + ) + + return obs, state + + def get_monitored_metrics(self): + return () + + def step_env( + self, + key: chex.PRNGKey, + state: UEDEnvState, + action: int, + ) -> Tuple[chex.Array, UEDEnvState, float, bool, dict]: + """ + Take a design step. + action: A pos as an int from 0 to (height*width)-1 + """ + params = self.params + + collision_rng, noise_rng = jax.random.split(key) + + # Sample a random free tile in case of a collision + dist_values = jnp.logical_and( # True if position taken + jnp.ones(params.n_walls + 10), + jnp.arange(params.n_walls + 10)+1 > state.time + ) + + # Get zero-indexed last wall time step + if params.fixed_n_wall_steps: + max_n_walls = params.n_walls + encoding_pos = state.encoding[:params.n_walls+10] + last_wall_step_idx = max_n_walls - 1 + else: + max_n_walls = jnp.round( + params.n_walls*state.encoding[0]/self.n_tiles).astype(jnp.uint32) + + if self.params.first_wall_pos_sets_budget: + encoding_pos = state.encoding[:params.n_walls+10] + last_wall_step_idx = jnp.maximum(max_n_walls, 1) - 1 + else: + encoding_pos = state.encoding[1:params.n_walls+11] + last_wall_step_idx = max_n_walls + + pos_dist = jnp.ones(self.n_tiles).at[ + jnp.flip(encoding_pos)].set(jnp.flip(dist_values)) + all_pos = jnp.arange(self.n_tiles, dtype=jnp.uint8) + + agent_step_1_idx = last_wall_step_idx+1 # Enc is full length + agent_step_2_idx = last_wall_step_idx+2 + + # Track whether it is the last time step + next_state = state.replace(time=state.time + 1) + done = self.is_terminal(next_state) + + collision = jnp.logical_and( + pos_dist[action] < 1, + jnp.logical_or( + not params.replace_wall_pos, + jnp.logical_and( # agent pos cannot be overriden + # jnp.logical_or(), + # jnp.equal(state.time, goal_step_1_idx), + jnp.equal(state.encoding[agent_step_1_idx], action), + jnp.equal(state.encoding[agent_step_2_idx], action) + ) + ) + ) + # collision = (collision * (1-is_agent_dir_step)).astype(jnp.uint32) + + action = (1-collision)*action + \ + collision*jax.random.choice(collision_rng, + all_pos, replace=False, p=pos_dist) + + # (1-is_agent_dir_step)* # + is_agent_dir_step*(-1) + enc_idx = state.time + encoding = state.encoding.at[enc_idx].set(action) + + next_state = next_state.replace( + encoding=encoding, + terminal=done + ) + reward = 0 + + obs = self._add_noise_to_obs(noise_rng, self.get_obs(next_state)) + + # jax.debug.breakpoint() + return ( + lax.stop_gradient(obs), + lax.stop_gradient(next_state), + reward, + done, + {}, + ) + + def get_env_instance( + self, + key: chex.PRNGKey, + state: UEDEnvState + ) -> chex.Array: + """ + Converts internal encoding to an instance encoding that + can be interpreted by the `set_to_instance` method + the paired Environment class. + """ + params = self.params + h = params.height + w = params.width + enc = state.encoding + + # === Extract agent_dir, agent_pos, and goal_pos === + # Num walls placed currently + if params.fixed_n_wall_steps: + n_walls = params.n_walls + enc_len = self._get_encoding_dim() + wall_pos_idx = jnp.flip(enc[:params.n_walls]) + agent_pos_1_idx = enc_len-2 # Enc is full length + agent_pos_2_idx = enc_len-3 + goal_pos_1_idx = enc_len-4 + onion_pos_1_idx = enc_len-6 + pot_pos_1_idx = enc_len-8 + bowl_pos_1_idx = enc_len-10 + else: + n_walls = jnp.round( + params.n_walls*enc[0]/self.n_tiles + ).astype(jnp.uint32) + if params.first_wall_pos_sets_budget: + # So 0-padding does not override pos=0 + wall_pos_idx = jnp.flip(enc[:params.n_walls]) + enc_len = n_walls + 10 # [wall_pos] + len((goal, agent)) + else: + wall_pos_idx = jnp.flip(enc[1:params.n_walls+1]) + # [wall_pos] + len((n_walls, goal, agent)) + enc_len = n_walls + 11 + agent_pos_1_idx = enc_len-1 # Enc is full length + agent_pos_2_idx = enc_len-2 + goal_pos_1_idx = enc_len-3 + onion_pos_1_idx = enc_len-5 + pot_pos_1_idx = enc_len-7 + bowl_pos_1_idx = enc_len-9 + + # Make wall map + wall_start_time = jnp.logical_and( # 1 if explicitly predict # blocks, else 0 + not params.fixed_n_wall_steps, + not params.first_wall_pos_sets_budget + ).astype(jnp.uint32) + + wall_map = jnp.zeros((h * w), dtype=jnp.bool_) + wall_values = jnp.arange( + params.n_walls) + wall_start_time < jnp.minimum(state.time, n_walls + wall_start_time) + wall_values = jnp.flip(wall_values) + wall_map = wall_map.at[wall_pos_idx].set(wall_values) + wall_map = wall_map.reshape((h, w)) + wall_map = wall_map.at[0, :].set(True) + wall_map = wall_map.at[:, 0].set(True) + wall_map = wall_map.at[-1, :].set(True) + wall_map = wall_map.at[:, -1].set(True) + wall_map = wall_map.reshape(-1) + + occupied_mask = wall_map + + """Agents should always end up on an empty square. If they are placed on a wall pick randomly.""" + is_occupied = occupied_mask[enc[agent_pos_1_idx]] == 1 + agent_pos_1_idx_enc = is_occupied*jax.random.choice(key, jnp.arange(h*w), shape=( + ), p=jnp.logical_not(occupied_mask)) + jnp.logical_not(is_occupied)*enc[agent_pos_1_idx] + agent_1_placed = state.time > jnp.array( + [agent_pos_1_idx], dtype=jnp.uint8) + agent_1_pos = \ + agent_1_placed*jnp.array([agent_pos_1_idx_enc % w, agent_pos_1_idx_enc // w], dtype=jnp.uint8) \ + + (~agent_1_placed)*jnp.array([h, w], dtype=jnp.uint8) + occupied_mask = occupied_mask.at[agent_pos_1_idx_enc].set(True) + + is_occupied = occupied_mask[enc[agent_pos_2_idx]] == 1 + agent_pos_2_idx_enc = is_occupied*jax.random.choice(key, jnp.arange( + h*w), shape=(), p=jnp.logical_not(occupied_mask)) + jnp.logical_not(is_occupied)*enc[agent_pos_2_idx] + agent_2_placed = state.time > jnp.array( + [agent_pos_2_idx], dtype=jnp.uint8) + agent_2_pos = \ + agent_2_placed*jnp.array([agent_pos_2_idx_enc % w, agent_pos_2_idx_enc // w], dtype=jnp.uint8) \ + + (~agent_2_placed)*jnp.array([h, w], dtype=jnp.uint8) + occupied_mask = occupied_mask.at[agent_pos_2_idx_enc].set(True) + + agents_obj_occupied_mask = jnp.zeros_like(occupied_mask) + agents_obj_occupied_mask = agents_obj_occupied_mask.reshape((h, w)) + # Exlude corners, will never be actually reachable + agents_obj_occupied_mask = agents_obj_occupied_mask.at[0, 0].set(True) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[0, -1].set(True) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[-1, 0].set(True) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[-1, -1].set( + True) + agents_obj_occupied_mask = agents_obj_occupied_mask.reshape(-1) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[ + agent_pos_1_idx_enc].set(True) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[ + agent_pos_2_idx_enc].set(True) + + """Objects can end up on a wall but never on a agent or another agent.""" + is_occupied = agents_obj_occupied_mask[enc[goal_pos_1_idx]] == 1 + goal_pos_1_idx_enc = is_occupied*jax.random.choice(key, jnp.arange( + h*w), shape=(), p=jnp.logical_not(agents_obj_occupied_mask)) + jnp.logical_not(is_occupied)*enc[goal_pos_1_idx] + goal_1_placed = state.time > jnp.array( + [goal_pos_1_idx], dtype=jnp.uint8) + goal_1_pos = \ + goal_1_placed*jnp.zeros((h*w), dtype=jnp.uint8).at[goal_pos_1_idx_enc].set(1) \ + + (~goal_1_placed)*jnp.zeros((h*w), dtype=jnp.uint8) + goal_1_pos = goal_1_pos.reshape((h, w)) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[ + goal_pos_1_idx_enc].set(True) + wall_map = wall_map.at[goal_pos_1_idx_enc].set(True) + + is_occupied = agents_obj_occupied_mask[enc[onion_pos_1_idx]] == 1 + onion_pos_1_idx_enc = is_occupied*jax.random.choice(key, jnp.arange( + h*w), shape=(), p=jnp.logical_not(agents_obj_occupied_mask)) + jnp.logical_not(is_occupied)*enc[onion_pos_1_idx] + onion_1_placed = state.time > jnp.array( + [onion_pos_1_idx], dtype=jnp.uint8) + onion_1_pos = \ + onion_1_placed*jnp.zeros((h*w), dtype=jnp.uint8).at[onion_pos_1_idx_enc].set(1) \ + + (~onion_1_placed)*jnp.zeros((h*w), dtype=jnp.uint8) + onion_1_pos = onion_1_pos.reshape((h, w)) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[ + onion_pos_1_idx_enc].set(True) + wall_map = wall_map.at[onion_pos_1_idx_enc].set(True) + + is_occupied = agents_obj_occupied_mask[enc[pot_pos_1_idx]] == 1 + pot_pos_1_idx_enc = is_occupied*jax.random.choice(key, jnp.arange( + h*w), shape=(), p=jnp.logical_not(agents_obj_occupied_mask)) + jnp.logical_not(is_occupied)*enc[pot_pos_1_idx] + pot_1_placed = state.time > jnp.array( + [pot_pos_1_idx], dtype=jnp.uint8) + pot_1_pos = \ + pot_1_placed*jnp.zeros((h*w), dtype=jnp.uint8).at[pot_pos_1_idx_enc].set(1) \ + + (~pot_1_placed)*jnp.zeros((h*w), dtype=jnp.uint8) + pot_1_pos = pot_1_pos.reshape((h, w)) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[ + pot_pos_1_idx_enc].set(True) + wall_map = wall_map.at[pot_pos_1_idx_enc].set(True) + + is_occupied = agents_obj_occupied_mask[enc[bowl_pos_1_idx]] == 1 + bowl_pos_1_idx_enc = is_occupied*jax.random.choice(key, jnp.arange( + h*w), shape=(), p=jnp.logical_not(agents_obj_occupied_mask)) + jnp.logical_not(is_occupied)*enc[bowl_pos_1_idx] + bowl_1_placed = state.time > jnp.array( + [bowl_pos_1_idx], dtype=jnp.uint8) + bowl_1_pos = \ + bowl_1_placed*jnp.zeros((h*w), dtype=jnp.uint8).at[bowl_pos_1_idx_enc].set(1) \ + + (~bowl_1_placed)*jnp.zeros((h*w), dtype=jnp.uint8) + bowl_1_pos = bowl_1_pos.reshape((h, w)) + agents_obj_occupied_mask = agents_obj_occupied_mask.at[ + bowl_pos_1_idx_enc].set(True) + wall_map = wall_map.at[bowl_pos_1_idx_enc].set(True) + + # agent_dir_idx = jnp.floor((4*enc[-1]/self.n_tiles)).astype(jnp.uint8) + key, subkey = jax.random.split(key) + agent_dir_idx = jax.random.choice(subkey, jnp.arange( + len(DIR_TO_VEC), dtype=jnp.int32), shape=(2,)) + + # Zero out walls where agent and goal reside + # Should not be the case but just in case + agent_1_mask = agent_1_placed * \ + (~(jnp.arange(h*w) == agent_pos_1_idx_enc)) + ~agent_1_placed*wall_map + agent_2_mask = agent_2_placed * \ + (~(jnp.arange(h*w) == agent_pos_2_idx_enc)) + ~agent_2_placed*wall_map + goal_mask = goal_1_placed * \ + (~(jnp.arange(h*w) == goal_pos_1_idx_enc)) + ~goal_1_placed*wall_map + wall_map = wall_map*agent_1_mask*agent_2_mask + wall_map = wall_map.reshape(h, w) + + possible_items = jnp.array([OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['onion'], + OBJECT_TO_INDEX['plate'], OBJECT_TO_INDEX['dish']]) + key, subkey = jax.random.split(key) + random_agent_inv = jax.random.choice( + subkey, possible_items, shape=(2,), replace=True) + + return EnvInstance( + agent_pos=jnp.array([agent_1_pos, agent_2_pos], dtype=jnp.uint32), + agent_dir_idx=agent_dir_idx, + goal_pos=goal_1_pos, + wall_map=wall_map, + onion_pile_pos=onion_1_pos, + pot_pos=pot_1_pos, + plate_pile_pos=bowl_1_pos, + agent_inv=random_agent_inv + ) + + def is_terminal(self, state: UEDEnvState) -> bool: + done_steps = state.time >= self.max_episode_steps() + return jnp.logical_or(done_steps, state.terminal) + + def _get_post_terminal_obs(self, state: UEDEnvState): + dtype = jnp.float32 if self.params.normalize_obs else jnp.uint8 + image = jnp.zeros(( + self.params.height+2, self.params.width+2, 3), dtype=dtype + ) + + return OrderedDict(dict( + image=image, + time=state.time, + noise=jnp.zeros(self.params.noise_dim, dtype=jnp.float32), + )) + + def get_obs(self, state: UEDEnvState): + instance = self.get_env_instance(jax.random.PRNGKey(0), state) + h = self.params.height + w = self.params.width + onion_pos = jnp.zeros((h, w), dtype=jnp.uint8) + plate_pos = jnp.zeros((h, w), dtype=jnp.uint8) + dish_pos = jnp.zeros((h, w), dtype=jnp.uint8) + + pot_status = jnp.ones( + (instance.wall_map.reshape(-1).shape), dtype=jnp.uint8) * 23 + + agent_dir = DIR_TO_VEC.at[instance.agent_dir_idx].get() + + maze_map = make_overcooked_map( + wall_map=instance.wall_map, + goal_pos=instance.goal_pos, + agent_pos=instance.agent_pos, + agent_dir_idx=instance.agent_dir_idx, + plate_pile_pos=instance.plate_pile_pos, + onion_pile_pos=instance.onion_pile_pos, + pot_pos=instance.pot_pos, + pot_status=pot_status, + onion_pos=onion_pos, + plate_pos=plate_pos, + dish_pos=dish_pos, + pad_obs=True, + num_agents=2, + agent_view_size=5 + ) + + padding = 4 + return OrderedDict(dict( + image=maze_map[padding:-padding, padding:-padding, :], + time=state.time, + )) + + @property + def default_params(self): + return EnvParams() + + @property + def name(self) -> str: + """Environment name.""" + return "UEDOvercooked" + + @property + def num_actions(self) -> int: + """Number of actions possible in environment.""" + return len(self.action_set) + + def action_space(self) -> spaces.Discrete: + """Action space of the environment.""" + params = self.params + return spaces.Discrete( + params.height*params.width, + dtype=jnp.uint8 + ) + + def observation_space(self) -> spaces.Dict: + """Observation space of the environment.""" + params = self.params + max_episode_steps = self.max_episode_steps() + spaces_dict = { + 'image': spaces.Box(0, 255, (params.height, params.width, 3)), + 'time': spaces.Discrete(max_episode_steps), + } + if self.params.noise_dim > 0: + spaces_dict.update({ + 'noise': spaces.Box(0, 1, (self.params.noise_dim,)) + }) + return spaces.Dict(spaces_dict) + + def state_space(self) -> spaces.Dict: + """State space of the environment.""" + params = self.params + encoding_dim = self._get_encoding_dim() + max_episode_steps = self.max_episode_steps() + h = params.height + w = params.width + return spaces.Dict({ + 'encoding': spaces.Box(0, 255, (encoding_dim,)), + 'time': spaces.Discrete(max_episode_steps), + "terminal": spaces.Discrete(2), + }) + + def _get_encoding_dim(self) -> int: + encoding_dim = self.max_episode_steps() + # if not self.params.set_agent_dir: + # encoding_dim += 1 # max steps is 1 less than full encoding dim + + return encoding_dim + + def max_episode_steps(self) -> int: + if self.params.fixed_n_wall_steps \ + or self.params.first_wall_pos_sets_budget: + max_episode_steps = self.params.n_walls + 10 + else: + max_episode_steps = self.params.n_walls + 11 + + return max_episode_steps + + +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register_ued(env_id='Overcooked', entry_point=module_path + ':UEDOvercooked') diff --git a/src/minimax/envs/registration.py b/src/minimax/envs/registration.py new file mode 100644 index 0000000..be4460d --- /dev/null +++ b/src/minimax/envs/registration.py @@ -0,0 +1,149 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import importlib + +from .wrappers import * +from minimax.envs.environment_ued import UEDEnvironment + + +# Global registry +registered_envs = [] + +env2entry = {} + +env2ued_entry = {} + +env2comparator_entry = {} + +env2mutator_entry = {} + +name2wrapper = { + 'env_wrapper': EnvWrapper, # for testing, + 'world_state_wrapper': WorldStateWrapper, + 'ued_env_wrapper': UEDEnvWrapper, + 'monitor_return': MonitorReturnWrapper, + 'monitor_ep_metrics': MonitorEpisodicMetricsWrapper, +} + + +def _load(name): + mod_name, attr_name = name.split(":") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +def _fn_for_entry(entry): + if callable(entry): + return entry + else: + return _load(entry) + + +def cls_for_env_id(env_id): + if env_id not in env2entry.keys(): + raise ValueError(f"{env_id} is not registered.") + else: + entry = env2entry[env_id] + return _fn_for_entry(entry) + + +def _make_env(entry, **env_kwargs): + return _fn_for_entry(entry)(**env_kwargs) + + +def _apply_wrappers(env, wrappers): + base_env = env + if wrappers is not None and len(wrappers) > 0: + for name in wrappers: + wrapper_cls = name2wrapper[name] + if wrapper_cls.is_compatible(base_env): + env = wrapper_cls(env) + + return env + + +def make( + env_id: str, + env_kwargs={}, + ued_env_kwargs={}, + wrappers=None, + ued_wrappers=None +): + """The minimax equivalent of OpenAI's env.make(env_name)""" + if env_id not in env2entry.keys(): + raise ValueError(f"{env_id} is not registered.") + else: + entry = env2entry[env_id] + env = _make_env(entry, **env_kwargs) + + if len(ued_env_kwargs) > 0: + if env_id not in env2ued_entry.keys(): + raise ValueError(f"{env_id} has no UED counterpart registered.") + + _env_kwargs = env_kwargs + env_kwargs = env.default_params.__dict__ + env_kwargs.update(_env_kwargs) + + ued_entry = env2ued_entry[env_id] + ued_env_kwargs = _fn_for_entry(ued_entry).align_kwargs(ued_env_kwargs, env_kwargs) + + ued_env = _make_env(ued_entry, **ued_env_kwargs) + + env = UEDEnvironment(env=env, ued_env=ued_env) + + base_env = env + + env = _apply_wrappers(env, wrappers) + + if isinstance(base_env, UEDEnvironment): + env = _apply_wrappers(env, ued_wrappers) + return env, env.env.params, env.ued_env.params + + return env, env.params + + +def get_comparator(env_id: str, comparator_id: str = 'default'): + entry_point = env2comparator_entry[env_id].get(comparator_id, None) + assert entry_point is not None, f'Undefined comparator {comparator_id} for environment {env_id}.' + + return _fn_for_entry(entry_point) + + +def get_mutator(env_id: str, mutator_id: str = 'default'): + entry_point = env2mutator_entry[env_id].get(mutator_id, None) + assert entry_point is not None, f'Undefined mutator {mutator_id} for environment {env_id}.' + + return _fn_for_entry(entry_point) + + +def register(env_id: str, entry_point: str): + env2entry[env_id] = entry_point + + +def register_ued(env_id: str, entry_point: str): + env2ued_entry[env_id] = entry_point + + +def register_comparator(env_id: str, entry_point: str, comparator_id: str = None): + if comparator_id is None: + comparator_id = 'default' + + if env_id not in env2comparator_entry: + env2comparator_entry[env_id] = {} + env2comparator_entry[env_id][comparator_id] = entry_point + + +def register_mutator(env_id: str, entry_point: str, mutator_id: str = None): + if mutator_id is None: + mutator_id = 'default' + + if env_id not in env2mutator_entry: + env2mutator_entry[env_id] = {} + env2mutator_entry[env_id][mutator_id] = entry_point diff --git a/src/minimax/envs/spaces.py b/src/minimax/envs/spaces.py new file mode 100644 index 0000000..90f289f --- /dev/null +++ b/src/minimax/envs/spaces.py @@ -0,0 +1,154 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This file is modified from +https://github.com/RobertTLange/gymnax/ + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Tuple, Union +from collections import OrderedDict +import chex +import jax +import jax.numpy as jnp + + +class Discrete(object): + """ + Minimal jittable class for discrete gymnax spaces. + TODO: For now this is a 1d space. Make composable for multi-discrete. + """ + + def __init__(self, num_categories: int, shape: Tuple[int] = (), dtype=jnp.int32): + assert num_categories >= 0 + self.n = num_categories + self.shape = shape + self.dtype = dtype + + def sample(self, rng: chex.PRNGKey) -> chex.Array: + """Sample random action uniformly from set of categorical choices.""" + return jax.random.randint( + rng, shape=self.shape, minval=0, maxval=self.n + ).astype(self.dtype) + + def contains(self, x: jnp.int_) -> bool: + """Check whether specific object is within space.""" + # type_cond = isinstance(x, self.dtype) + # shape_cond = (x.shape == self.shape) + range_cond = jnp.logical_and(x >= 0, x < self.n) + return range_cond + + +class Box(object): + """ + Minimal jittable class for array-shaped gymnax spaces. + """ + + def __init__( + self, + low: float, + high: float, + shape: Tuple[int], + dtype: jnp.dtype = jnp.float32, + ): + self.low = low + self.high = high + self.shape = shape + self.dtype = dtype + + def sample(self, rng: chex.PRNGKey) -> chex.Array: + """Sample random action uniformly from 1D continuous range.""" + return jax.random.uniform( + rng, shape=self.shape, minval=self.low, maxval=self.high + ).astype(self.dtype) + + def contains(self, x: jnp.int_) -> bool: + """Check whether specific object is within space.""" + # type_cond = isinstance(x, self.dtype) + # shape_cond = (x.shape == self.shape) + range_cond = jnp.logical_and( + jnp.all(x >= self.low), jnp.all(x <= self.high) + ) + return range_cond + + +class Dict(object): + """Minimal jittable class for dictionary of simpler jittable spaces.""" + + def __init__(self, spaces: dict): + self.spaces = spaces + self.num_spaces = len(spaces) + + def sample(self, rng: chex.PRNGKey) -> dict: + """Sample random action from all subspaces.""" + key_split = jax.random.split(rng, self.num_spaces) + return OrderedDict( + [ + (k, self.spaces[k].sample(key_split[i])) + for i, k in enumerate(self.spaces) + ] + ) + + def contains(self, x: jnp.int_) -> bool: + """Check whether dimensions of object are within subspace.""" + # type_cond = isinstance(x, dict) + # num_space_cond = len(x) != len(self.spaces) + # Check for each space individually + out_of_space = 0 + for k, space in self.spaces.items(): + out_of_space += 1 - space.contains(getattr(x, k)) + return out_of_space == 0 + + +class Tuple(object): + """Minimal jittable class for tuple (product) of jittable spaces.""" + + def __init__(self, spaces: Union[tuple, list]): + self.spaces = spaces + self.num_spaces = len(spaces) + + def sample(self, rng: chex.PRNGKey) -> Tuple[chex.Array]: + """Sample random action from all subspaces.""" + key_split = jax.random.split(rng, self.num_spaces) + return tuple( + [ + space.sample(key_split[i]) + for i, space in enumerate(self.spaces) + ] + ) + + def contains(self, x: jnp.int_) -> bool: + """Check whether dimensions of object are within subspace.""" + # type_cond = isinstance(x, tuple) + # num_space_cond = len(x) != len(self.spaces) + # Check for each space individually + out_of_space = 0 + for space in self.spaces: + out_of_space += 1 - space.contains(x) + return out_of_space == 0 + + +class Dummy(object): + def __init__(self, default_value=None): + self._default_value = default_value + self.dtype = jnp.uint32 + if self._default_value is None: + self.n = 0 + else: + self.n = 1 + + def sample(self, rng: chex.PRNGKey): + if self._default_value is None: + return None + else: + return jnp.array(0, dtype=self.dtype) + + def contains(self, x: jnp.int_) -> bool: + return False diff --git a/src/minimax/envs/viz/__init__.py b/src/minimax/envs/viz/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/minimax/envs/viz/grid_rendering.py b/src/minimax/envs/viz/grid_rendering.py new file mode 100644 index 0000000..9f7947f --- /dev/null +++ b/src/minimax/envs/viz/grid_rendering.py @@ -0,0 +1,133 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This file is modified from +https://github.com/Farama-Foundation/Minigrid + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +""" + +import math +import numpy as np + + +def downsample(img, factor): + """ + Downsample an image along both dimensions by some factor + """ + + assert img.shape[0] % factor == 0 + assert img.shape[1] % factor == 0 + + img = img.reshape([img.shape[0]//factor, factor, img.shape[1]//factor, factor, 3]) + img = img.mean(axis=3) + img = img.mean(axis=1) + + return img + +def fill_coords(img, fn, color): + """ + Fill pixels of an image with coordinates matching a filter function + """ + + for y in range(img.shape[0]): + for x in range(img.shape[1]): + yf = (y + 0.5) / img.shape[0] + xf = (x + 0.5) / img.shape[1] + if fn(xf, yf): + img[y, x] = color + + return img + +def rotate_fn(fin, cx, cy, theta): + def fout(x, y): + x = x - cx + y = y - cy + + x2 = cx + x * math.cos(-theta) - y * math.sin(-theta) + y2 = cy + y * math.cos(-theta) + x * math.sin(-theta) + + return fin(x2, y2) + + return fout + +def point_in_line(x0, y0, x1, y1, r): + p0 = np.array([x0, y0]) + p1 = np.array([x1, y1]) + dir = p1 - p0 + dist = np.linalg.norm(dir) + dir = dir / dist + + xmin = min(x0, x1) - r + xmax = max(x0, x1) + r + ymin = min(y0, y1) - r + ymax = max(y0, y1) + r + + def fn(x, y): + # Fast, early escape test + if x < xmin or x > xmax or y < ymin or y > ymax: + return False + + q = np.array([x, y]) + pq = q - p0 + + # Closest point on line + a = np.dot(pq, dir) + a = np.clip(a, 0, dist) + p = p0 + a * dir + + dist_to_line = np.linalg.norm(q - p) + return dist_to_line <= r + + return fn + +def point_in_circle(cx, cy, r): + def fn(x, y): + return (x-cx)*(x-cx) + (y-cy)*(y-cy) <= r * r + return fn + +def point_in_rect(xmin, xmax, ymin, ymax): + def fn(x, y): + return x >= xmin and x <= xmax and y >= ymin and y <= ymax + return fn + +def point_in_triangle(a, b, c): + a = np.array(a) + b = np.array(b) + c = np.array(c) + + def fn(x, y): + v0 = c - a + v1 = b - a + v2 = np.array((x, y)) - a + + # Compute dot products + dot00 = np.dot(v0, v0) + dot01 = np.dot(v0, v1) + dot02 = np.dot(v0, v2) + dot11 = np.dot(v1, v1) + dot12 = np.dot(v1, v2) + + # Compute barycentric coordinates + inv_denom = 1 / (dot00 * dot11 - dot01 * dot01) + u = (dot11 * dot02 - dot01 * dot12) * inv_denom + v = (dot00 * dot12 - dot01 * dot02) * inv_denom + + # Check if point is in triangle + return (u >= 0) and (v >= 0) and (u + v) < 1 + + return fn + +def highlight_img(img, color=(255, 255, 255), alpha=0.30): + """ + Add highlighting to an image + """ + # color = [60, 182, 234] + blend_img = img + alpha * (np.array(color, dtype=np.uint8) - img) + blend_img = blend_img.clip(0, 255).astype(np.uint8) + img[:, :, :] = blend_img diff --git a/src/minimax/envs/viz/grid_viz.py b/src/minimax/envs/viz/grid_viz.py new file mode 100644 index 0000000..d792aa0 --- /dev/null +++ b/src/minimax/envs/viz/grid_viz.py @@ -0,0 +1,272 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This file is modified from +https://github.com/Farama-Foundation/Minigrid + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +""" + +import math + +import numpy as np + +from minimax.envs.viz.window import Window +import minimax.envs.viz.grid_rendering as rendering +from minimax.envs.overcooked_proc.overcooked import OBJECT_TO_INDEX, COLOR_TO_INDEX, COLORS + + +INDEX_TO_COLOR = [k for k, v in COLOR_TO_INDEX.items()] +TILE_PIXELS = 32 + + +class GridVisualizer: + """ + Manages a window and renders contents of EnvState instances to it. + """ + tile_cache = {} + + def __init__(self): + self.window = None + + def _lazy_init_window(self): + if self.window is None: + self.window = Window('minimax') + + def show(self, block=False): + self._lazy_init_window() + self.window.show(block=block) + + def screenshot(self, path): + self._lazy_init_window() + self.window.save_img(path) + + def render(self, params, state, highlight=True, tile_size=TILE_PIXELS, maze_map=None): + return self._render_state(params, state, highlight, tile_size, maze_map) + + def render_grid(self, grid, tile_size=TILE_PIXELS, k_rot90=0, agent_dir_idx=None): + self._lazy_init_window() + + img = GridVisualizer._render_grid( + grid, + tile_size, + highlight_mask=None, + agent_dir_idx=agent_dir_idx, + ) + + if k_rot90 > 0: + img = np.rot90(img, k=k_rot90) + + self.window.show_img(img) + + def _render_state(self, params, state, highlight=True, tile_size=TILE_PIXELS, maze_map=None): + """ + Render the state + """ + self._lazy_init_window() + + if hasattr(params, 'agent_view_size'): + agent_view_size = params.agent_view_size + padding = agent_view_size-2 # show + # padding = 4 + grid = np.asarray( + state.maze_map[padding:-padding, padding:-padding, :]) + else: + assert maze_map is not None, 'Either params contains agent_view_size or explicit maze map is passed in.' + grid = np.asarray(maze_map) + + grid_offset = np.array([1, 1]) + h, w = grid.shape[:2] + + # === Compute highlight mask + if highlight: + highlight_mask = np.zeros(shape=(h, w), dtype=bool) + + f_vec = state.agent_dir + r_vec = np.array([-f_vec[1], f_vec[0]]) + + fwd_bound1 = state.agent_pos + fwd_bound2 = state.agent_pos + state.agent_dir*(agent_view_size-1) + side_bound1 = state.agent_pos - r_vec*(agent_view_size//2) + side_bound2 = state.agent_pos + r_vec*(agent_view_size//2) + + min_bound = np.min(np.stack([ + fwd_bound1, + fwd_bound2, + side_bound1, + side_bound2]) + grid_offset, 0) + + min_y = min(max(min_bound[1], 0), highlight_mask.shape[0]-1) + min_x = min(max(min_bound[0], 0), highlight_mask.shape[1]-1) + + max_y = min( + max(min_bound[1]+agent_view_size, 0), highlight_mask.shape[0]-1) + max_x = min( + max(min_bound[0]+agent_view_size, 0), highlight_mask.shape[1]-1) + + highlight_mask[min_y:max_y, min_x:max_x] = True + + # Render the whole grid + img = GridVisualizer._render_grid( + grid, + tile_size, + highlight_mask=highlight_mask if highlight else None, + agent_dir_idx=state.agent_dir_idx if hasattr( + state, 'agent_dir_idx') else 0 + ) + + self.window.show_img(img) + + @classmethod + def _render_obj( + cls, + obj, + img): + # Render each kind of object + obj_type = obj[0] + color = INDEX_TO_COLOR[obj[1]] + + if obj_type == OBJECT_TO_INDEX['wall']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['goal']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['agent']: + agent_dir_idx = obj[2] + tri_fn = rendering.point_in_triangle( + (0.12, 0.19), + (0.87, 0.50), + (0.12, 0.81), + ) + tri_fn = rendering.rotate_fn( + tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir_idx) + rendering.fill_coords(img, tri_fn, (255, 0, 0)) + # rendering.fill_coords(img, tri_fn, (60, 182, 234)) + + elif obj_type == OBJECT_TO_INDEX['empty']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['lava']: + c = (255, 128, 0) + + # Background color + rendering.fill_coords(img, rendering.point_in_rect(0, 1, 0, 1), c) + + # Little waves + for i in range(3): + ylo = 0.3 + 0.2 * i + yhi = 0.4 + 0.2 * i + rendering.fill_coords(img, rendering.point_in_line( + 0.1, ylo, 0.3, yhi, r=0.03), (0, 0, 0)) + rendering.fill_coords(img, rendering.point_in_line( + 0.3, yhi, 0.5, ylo, r=0.03), (0, 0, 0)) + rendering.fill_coords(img, rendering.point_in_line( + 0.5, ylo, 0.7, yhi, r=0.03), (0, 0, 0)) + rendering.fill_coords(img, rendering.point_in_line( + 0.7, yhi, 0.9, ylo, r=0.03), (0, 0, 0)) + else: + raise ValueError( + f'Rendering object at index {obj_type} is currently unsupported.') + + @classmethod + def _render_tile( + cls, + obj, + highlight=False, + agent_dir_idx=None, + tile_size=TILE_PIXELS, + subdivs=3 + ): + """ + Render a tile and cache the result + """ + # Hash map lookup key for the cache + if obj is not None and \ + obj[0] == OBJECT_TO_INDEX['agent'] and \ + agent_dir_idx is not None: + obj = np.array(obj) + obj[-1] = agent_dir_idx + + no_object = obj is None or ( + obj[0] in [OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['unseen']] + and obj[2] == 0 + ) + + if not no_object: + key = (*obj, highlight, tile_size) + else: + key = (obj, highlight, tile_size) + + if key in cls.tile_cache: + return cls.tile_cache[key] + + img = np.zeros(shape=(tile_size * subdivs, + tile_size * subdivs, 3), dtype=np.uint8) + + # Draw the grid lines (top and left edges) + rendering.fill_coords(img, rendering.point_in_rect( + 0, 0.031, 0, 1), (100, 100, 100)) + rendering.fill_coords(img, rendering.point_in_rect( + 0, 1, 0, 0.031), (100, 100, 100)) + + if not no_object: + GridVisualizer._render_obj(obj, img) + + if highlight: + rendering.highlight_img(img) + + # Downsample the image to perform supersampling/anti-aliasing + img = rendering.downsample(img, subdivs) + + # Cache the rendered tile + cls.tile_cache[key] = img + + return img + + @classmethod + def _render_grid( + cls, + grid, + tile_size=TILE_PIXELS, + highlight_mask=None, + agent_dir_idx=None): + if highlight_mask is None: + highlight_mask = np.zeros(shape=grid.shape[:2], dtype=np.bool_) + + # Compute the total grid size in pixels + width_px = grid.shape[1]*tile_size + height_px = grid.shape[0]*tile_size + + img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8) + + # Render the grid + for y in range(grid.shape[0]): + for x in range(grid.shape[1]): + obj = grid[y, x, :] + if obj[0] in [OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['unseen']] \ + and obj[2] == 0: + obj = None + + tile_img = GridVisualizer._render_tile( + obj, + highlight=highlight_mask[y, x], + tile_size=tile_size, + agent_dir_idx=agent_dir_idx, + ) + + ymin = y*tile_size + ymax = (y+1)*tile_size + xmin = x*tile_size + xmax = (x+1)*tile_size + img[ymin:ymax, xmin:xmax, :] = tile_img + + return img + + def close(self): + self.window.close() diff --git a/src/minimax/envs/viz/overcooked_visualizer.py b/src/minimax/envs/viz/overcooked_visualizer.py new file mode 100644 index 0000000..022a6b8 --- /dev/null +++ b/src/minimax/envs/viz/overcooked_visualizer.py @@ -0,0 +1,378 @@ +import math + +import numpy as np + +from minimax.envs.viz.window import Window +import minimax.envs.viz.grid_rendering as rendering +from minimax.envs.overcooked_proc.common import OBJECT_TO_INDEX, COLOR_TO_INDEX, COLORS + + +INDEX_TO_COLOR = [k for k, v in COLOR_TO_INDEX.items()] +TILE_PIXELS = 32 + +COLOR_TO_AGENT_INDEX = {0: 0, 2: 1} # Hardcoded. Red is first, blue is second + + +class OvercookedVisualizer: + """ + Manages a window and renders contents of EnvState instances to it. + """ + tile_cache = {} + + def __init__(self): + self.window = None + + def _lazy_init_window(self): + if self.window is None: + self.window = Window('minimax') + + def show(self, block=False): + self._lazy_init_window() + self.window.show(block=block) + + def render(self, agent_view_size, state, highlight=True, tile_size=TILE_PIXELS): + """Method for rendering the state in a window. Esp. useful for interactive mode.""" + return self._render_state(agent_view_size, state, highlight, tile_size) + + def animate(self, state_seq, agent_view_size, filename="animation.gif"): + """Animate a gif give a state sequence and save if to file.""" + import imageio + + padding = agent_view_size - 2 # show + + def get_frame(state): + grid = np.asarray( + state.maze_map[padding:-padding, padding:-padding, :]) + # Render the state + frame = OvercookedVisualizer._render_grid( + grid, + tile_size=TILE_PIXELS, + highlight_mask=None, + agent_dir_idx=state.agent_dir_idx, + agent_inv=state.agent_inv + ) + return frame + + frame_seq = [get_frame(state) for state in state_seq] + + imageio.mimsave(filename, frame_seq, 'GIF', duration=0.5) + + def render_grid(self, grid, tile_size=TILE_PIXELS, k_rot90=0, agent_dir_idx=None): + self._lazy_init_window() + + img = OvercookedVisualizer._render_grid( + grid, + tile_size, + highlight_mask=None, + agent_dir_idx=agent_dir_idx, + ) + # img = np.transpose(img, axes=(1,0,2)) + if k_rot90 > 0: + img = np.rot90(img, k=k_rot90) + + self.window.show_img(img) + + def _render_state(self, agent_view_size, state, highlight=True, tile_size=TILE_PIXELS): + """ + Render the state + """ + self._lazy_init_window() + + padding = agent_view_size-2 # show + grid = np.asarray( + state.maze_map[padding:-padding, padding:-padding, :]) + grid_offset = np.array([1, 1]) + h, w = grid.shape[:2] + # === Compute highlight mask + highlight_mask = np.zeros(shape=(h, w), dtype=bool) + + if highlight: + f_vec = state.agent_dir + r_vec = np.array([-f_vec[1], f_vec[0]]) + + fwd_bound1 = state.agent_pos + fwd_bound2 = state.agent_pos + state.agent_dir*(agent_view_size-1) + side_bound1 = state.agent_pos - r_vec*(agent_view_size//2) + side_bound2 = state.agent_pos + r_vec*(agent_view_size//2) + + min_bound = np.min(np.stack([ + fwd_bound1, + fwd_bound2, + side_bound1, + side_bound2]) + grid_offset, 0) + + min_y = min(max(min_bound[1], 0), highlight_mask.shape[0]-1) + min_x = min(max(min_bound[0], 0), highlight_mask.shape[1]-1) + + max_y = min( + max(min_bound[1]+agent_view_size, 0), highlight_mask.shape[0]-1) + max_x = min( + max(min_bound[0]+agent_view_size, 0), highlight_mask.shape[1]-1) + + highlight_mask[min_y:max_y, min_x:max_x] = True + + # Render the whole grid + img = OvercookedVisualizer._render_grid( + grid, + tile_size, + highlight_mask=highlight_mask if highlight else None, + agent_dir_idx=state.agent_dir_idx, + agent_inv=state.agent_inv + ) + self.window.show_img(img) + + @classmethod + def _render_obj( + cls, + obj, + img): + # Render each kind of object + obj_type = obj[0] + color = INDEX_TO_COLOR[obj[1]] + + if obj_type == OBJECT_TO_INDEX['wall']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['goal']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + rendering.fill_coords(img, rendering.point_in_rect( + 0.1, 0.9, 0.1, 0.9), COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['agent']: + agent_dir_idx = obj[2] + tri_fn = rendering.point_in_triangle( + (0.12, 0.19), + (0.87, 0.50), + (0.12, 0.81), + ) + tri_fn = rendering.rotate_fn( + tri_fn, cx=0.5, cy=0.5, theta=0.5*math.pi*agent_dir_idx) + rendering.fill_coords(img, tri_fn, COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['empty']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['onion_pile']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + onion_fns = [rendering.point_in_circle(*coord, 0.15) for coord in [(0.5, 0.15), (0.3, 0.4), (0.8, 0.35), + (0.4, 0.8), (0.75, 0.75)]] + [rendering.fill_coords(img, onion_fn, COLORS[color]) + for onion_fn in onion_fns] + elif obj_type == OBJECT_TO_INDEX['onion']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + onion_fn = rendering.point_in_circle(0.5, 0.5, 0.15) + rendering.fill_coords(img, onion_fn, COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['plate_pile']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + plate_fns = [rendering.point_in_circle(*coord, 0.2) for coord in [(0.3, 0.3), (0.75, 0.42), + (0.4, 0.75)]] + [rendering.fill_coords(img, plate_fn, COLORS[color]) + for plate_fn in plate_fns] + elif obj_type == OBJECT_TO_INDEX['plate']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + plate_fn = rendering.point_in_circle(0.5, 0.5, 0.2) + rendering.fill_coords(img, plate_fn, COLORS[color]) + elif obj_type == OBJECT_TO_INDEX['dish']: + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + plate_fn = rendering.point_in_circle(0.5, 0.5, 0.2) + rendering.fill_coords(img, plate_fn, COLORS[color]) + onion_fn = rendering.point_in_circle(0.5, 0.5, 0.13) + rendering.fill_coords(img, onion_fn, COLORS["orange"]) + elif obj_type == OBJECT_TO_INDEX['pot']: + OvercookedVisualizer._render_pot(obj, img) + # rendering.fill_coords(img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + # pot_fns = [rendering.point_in_rect(0.1, 0.9, 0.3, 0.9), + # rendering.point_in_rect(0.1, 0.9, 0.20, 0.23), + # rendering.point_in_rect(0.4, 0.6, 0.15, 0.20),] + # [rendering.fill_coords(img, pot_fn, COLORS[color]) for pot_fn in pot_fns] + else: + raise ValueError( + f'Rendering object at index {obj_type} is currently unsupported.') + + @classmethod + def _render_pot( + cls, + obj, + img): + pot_status = obj[-1] + num_onions = np.max([23-pot_status, 0]) + is_cooking = np.array((pot_status < 20) * (pot_status > 0)) + is_done = np.array(pot_status == 0) + + pot_fn = rendering.point_in_rect(0.1, 0.9, 0.33, 0.9) + lid_fn = rendering.point_in_rect(0.1, 0.9, 0.21, 0.25) + handle_fn = rendering.point_in_rect(0.4, 0.6, 0.16, 0.21) + + rendering.fill_coords( + img, rendering.point_in_rect(0, 1, 0, 1), COLORS["grey"]) + + # Render onions in pot + if num_onions > 0 and not is_done: + onion_fns = [rendering.point_in_circle( + *coord, 0.13) for coord in [(0.23, 0.33), (0.77, 0.33), (0.50, 0.33)]] + onion_fns = onion_fns[:num_onions] + [rendering.fill_coords(img, onion_fn, COLORS["yellow"]) + for onion_fn in onion_fns] + if not is_cooking: + lid_fn = rendering.rotate_fn( + lid_fn, cx=0.1, cy=0.25, theta=-0.1 * math.pi) + handle_fn = rendering.rotate_fn( + handle_fn, cx=0.1, cy=0.25, theta=-0.1 * math.pi) + + # Render done soup + if is_done: + soup_fn = rendering.point_in_rect(0.12, 0.88, 0.23, 0.35) + rendering.fill_coords(img, soup_fn, COLORS["orange"]) + + # Render the pot itself + pot_fns = [pot_fn, lid_fn, handle_fn] + [rendering.fill_coords(img, pot_fn, COLORS["black"]) + for pot_fn in pot_fns] + + # Render progress bar + if is_cooking: + progress_fn = rendering.point_in_rect( + 0.1, 0.9-(0.9-0.1)/20*pot_status, 0.83, 0.88) + rendering.fill_coords(img, progress_fn, COLORS["green"]) + + @classmethod + def _render_inv( + cls, + obj, + img): + # Render each kind of object + obj_type = obj[0] + if obj_type == OBJECT_TO_INDEX['empty']: + pass + elif obj_type == OBJECT_TO_INDEX['onion']: + onion_fn = rendering.point_in_circle(0.75, 0.75, 0.15) + rendering.fill_coords(img, onion_fn, COLORS["yellow"]) + elif obj_type == OBJECT_TO_INDEX['plate']: + plate_fn = rendering.point_in_circle(0.75, 0.75, 0.2) + rendering.fill_coords(img, plate_fn, COLORS["white"]) + elif obj_type == OBJECT_TO_INDEX['dish']: + plate_fn = rendering.point_in_circle(0.75, 0.75, 0.2) + rendering.fill_coords(img, plate_fn, COLORS["white"]) + onion_fn = rendering.point_in_circle(0.75, 0.75, 0.13) + rendering.fill_coords(img, onion_fn, COLORS["orange"]) + else: + raise ValueError( + f'Rendering object at index {obj_type} is currently unsupported.') + + @classmethod + def _render_tile( + cls, + obj, + highlight=False, + agent_dir_idx=None, + agent_inv=None, + tile_size=TILE_PIXELS, + subdivs=3 + ): + """ + Render a tile and cache the result + """ + # Hash map lookup key for the cache + if obj is not None and obj[0] == OBJECT_TO_INDEX['agent']: + # Get inventory of this specific agent + if agent_inv is not None: + color_idx = obj[1] + agent_inv = agent_inv[COLOR_TO_AGENT_INDEX[color_idx]] + agent_inv = np.array([agent_inv, -1, 0]) + + if agent_dir_idx is not None: + obj = np.array(obj) + + if len(agent_dir_idx) == 1: + # Hacky way of making agent views orientations consistent with global view + obj[-1] = agent_dir_idx[0] + + no_object = obj is None or ( + obj[0] in [OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['unseen']] + and obj[2] == 0 + ) + + if not no_object: + if agent_inv is not None and obj[0] == OBJECT_TO_INDEX['agent']: + key = (*obj, *agent_inv, highlight, tile_size) + else: + key = (*obj, highlight, tile_size) + else: + key = (obj, highlight, tile_size) + + if key in cls.tile_cache: + return cls.tile_cache[key] + + img = np.zeros(shape=(tile_size * subdivs, + tile_size * subdivs, 3), dtype=np.uint8) + + # Draw the grid lines (top and left edges) + rendering.fill_coords(img, rendering.point_in_rect( + 0, 0.031, 0, 1), (100, 100, 100)) + rendering.fill_coords(img, rendering.point_in_rect( + 0, 1, 0, 0.031), (100, 100, 100)) + + if not no_object: + OvercookedVisualizer._render_obj(obj, img) + # render inventory + if agent_inv is not None and obj[0] == OBJECT_TO_INDEX['agent']: + OvercookedVisualizer._render_inv(agent_inv, img) + + if highlight: + rendering.highlight_img(img) + + # Downsample the image to perform supersampling/anti-aliasing + img = rendering.downsample(img, subdivs) + + # Cache the rendered tile + cls.tile_cache[key] = img + + return img + + @classmethod + def _render_grid( + cls, + grid, + tile_size=TILE_PIXELS, + highlight_mask=None, + agent_dir_idx=None, + agent_inv=None): + if highlight_mask is None: + highlight_mask = np.zeros(shape=grid.shape[:2], dtype=bool) + + # Compute the total grid size in pixels + width_px = grid.shape[1]*tile_size + height_px = grid.shape[0]*tile_size + + img = np.zeros(shape=(height_px, width_px, 3), dtype=np.uint8) + + # Render the grid + for y in range(grid.shape[0]): + for x in range(grid.shape[1]): + obj = grid[y, x, :] + if obj[0] in [OBJECT_TO_INDEX['empty'], OBJECT_TO_INDEX['unseen']] \ + and obj[2] == 0: + obj = None + + tile_img = OvercookedVisualizer._render_tile( + obj, + highlight=highlight_mask[y, x], + tile_size=tile_size, + agent_dir_idx=agent_dir_idx, + agent_inv=agent_inv, + ) + + ymin = y*tile_size + ymax = (y+1)*tile_size + xmin = x*tile_size + xmax = (x+1)*tile_size + img[ymin:ymax, xmin:xmax, :] = tile_img + + return img + + def close(self): + self.window.close() diff --git a/src/minimax/envs/viz/window.py b/src/minimax/envs/viz/window.py new file mode 100644 index 0000000..d97665b --- /dev/null +++ b/src/minimax/envs/viz/window.py @@ -0,0 +1,107 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This file is modified from +https://github.com/Farama-Foundation/Minigrid + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +""" + +import sys +import numpy as np + +# Only ask users to install matplotlib if they actually need it +try: + import matplotlib.pyplot as plt +except: + print('To display the environment in a window, please install matplotlib, eg:') + print('pip3 install --user matplotlib') + sys.exit(-1) + +class Window: + """ + Window to draw a gridworld instance using Matplotlib + """ + + def __init__(self, title): + self.fig = None + + self.imshow_obj = None + + # Create the figure and axes + self.fig, self.ax = plt.subplots() + + # Show the env name in the window title + self.fig.canvas.manager.set_window_title(title) + + # Turn off x/y axis numbering/ticks + self.ax.set_xticks([], []) + self.ax.set_yticks([], []) + + # Flag indicating the window was closed + self.closed = False + + def close_handler(evt): + self.closed = True + + self.fig.canvas.mpl_connect('close_event', close_handler) + + def show_img(self, img): + """ + Show an image or update the image being shown + """ + + # Show the first image of the environment + if self.imshow_obj is None: + self.imshow_obj = self.ax.imshow(img, interpolation='bilinear') + + self.imshow_obj.set_data(img) + self.fig.canvas.draw() + + # Let matplotlib process UI events + # This is needed for interactive mode to work properly + plt.pause(0.001) + + def save_img(self, path): + plt.savefig(path, bbox_inches='tight', pad_inches=0) + + def set_caption(self, text): + """ + Set/update the caption text below the image + """ + + plt.xlabel(text) + + def reg_key_handler(self, key_handler): + """ + Register a keyboard event handler + """ + + # Keyboard handler + self.fig.canvas.mpl_connect('key_press_event', key_handler) + + def show(self, block=True): + """ + Show the window, and start an event loop + """ + + # If not blocking, trigger interactive mode + if not block: + plt.ion() + + # Show the plot + # In non-interative mode, this enters the matplotlib event loop + # In interactive mode, this call does not block + plt.show() + + def close(self): + """ + Close the window + """ + + plt.close() diff --git a/src/minimax/envs/wrappers/__init__.py b/src/minimax/envs/wrappers/__init__.py new file mode 100644 index 0000000..d53f4e9 --- /dev/null +++ b/src/minimax/envs/wrappers/__init__.py @@ -0,0 +1,14 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .env_wrapper import EnvWrapper +from .monitor_return import MonitorReturnWrapper +from .monitor_ep_metrics import MonitorEpisodicMetricsWrapper +from .world_state_wrapper import WorldStateWrapper + +from .ued_env_wrapper import UEDEnvWrapper \ No newline at end of file diff --git a/src/minimax/envs/wrappers/env_wrapper.py b/src/minimax/envs/wrappers/env_wrapper.py new file mode 100644 index 0000000..37e47c7 --- /dev/null +++ b/src/minimax/envs/wrappers/env_wrapper.py @@ -0,0 +1,114 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import chex +from typing import Tuple, Union, Optional + +from minimax.envs.environment import EnvState + + +class EnvWrapper: + """ + Abstract base class for an env wrapper. + """ + + def __init__(self, env): + self._env = env + + self._wrap_level = 1 + while hasattr(env, '_env'): + if isinstance(env, EnvWrapper): + self._wrap_level += 1 + + env = env._env + + @classmethod + def is_compatible(cls, env): + return True + + @property + def base_env(self): + env = self + for i in range(self._wrap_level): + env = env._env + + return env + + def reset_extra(self): + return {} + + def get_monitored_metrics(self): + # breakpoint() + # if self._wrap_level > 1: + return self._env.get_monitored_metrics() + # return () + + def _append_extra_to_tuple(self, _tuple, extra=None): + if extra is None: + extra = self.reset_extra() + + if self._wrap_level > 1 and len(_tuple) > 2: + _tuple[-1].update(extra) + else: + _tuple = _tuple + (extra,) + + return _tuple + + def step( + self, + key: chex.PRNGKey, + state: EnvState, + action: Union[int, float], + reset_state: Optional[chex.ArrayTree] = None, + extra: dict = None, + ) -> Tuple[chex.Array, EnvState, float, bool]: + if self._wrap_level > 1: + return self._env.step(key, state, action, reset_state, extra) + else: + _tuple = self._env.step( + key, state, action, reset_state=reset_state + ) + return self._append_extra_to_tuple(_tuple, extra) + + def reset( + self, + key: chex.PRNGKey, + ) -> Tuple[chex.Array, EnvState, chex.ArrayTree]: + _tuple = self._env.reset(key) + return self._append_extra_to_tuple(_tuple) + + def set_state( + self, + state: EnvState, + ) -> Tuple[chex.ArrayTree, EnvState, chex.ArrayTree]: + _tuple = self._env.set_state(state) + + return self._append_extra_to_tuple(_tuple) + + def set_env_instance( + self, + encoding: chex.ArrayTree, + ) -> Tuple[chex.ArrayTree, EnvState, chex.ArrayTree]: + _tuple = self._env.set_env_instance(encoding) + + return self._append_extra_to_tuple(_tuple) + + def reset_student( + self, + key: chex.PRNGKey, + state: chex.ArrayTree, + ) -> Tuple[chex.ArrayTree, EnvState, chex.ArrayTree]: + _tuple = self._env.reset_student(key, state) + + return self._append_extra_to_tuple(_tuple) + + def __getattr__(self, attr): + return getattr(self._env, attr) diff --git a/src/minimax/envs/wrappers/monitor_ep_metrics.py b/src/minimax/envs/wrappers/monitor_ep_metrics.py new file mode 100644 index 0000000..a7714f3 --- /dev/null +++ b/src/minimax/envs/wrappers/monitor_ep_metrics.py @@ -0,0 +1,84 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import jax.numpy as jnp +import chex +from typing import Tuple, Union, Optional + +from .env_wrapper import EnvWrapper +from minimax.envs.environment import EnvState + + +class MonitorEpisodicMetricsWrapper(EnvWrapper): + """ + Tracks episodic metrics about environment instances. + """ + def __init__(self, env): + super().__init__(env) + + base_env = env.base_env if hasattr(env, 'base_env') else env + _env = base_env.env if hasattr(base_env, 'env') else base_env + + self.metrics = () + if hasattr(_env, 'get_episodic_metrics'): + reset_tuple = _env.reset(jax.random.PRNGKey(0)) + dummy_state = reset_tuple[1] + + self.metrics = tuple({ + k: jnp.zeros_like(v) \ + for k,v in _env.get_episodic_metrics(dummy_state).items() + }.keys()) + + @classmethod + def is_compatible(cls, env): + _env = env.env if hasattr(env, 'env') else env + return hasattr(_env, 'get_episodic_metrics') + + def get_monitored_metrics(self): + metrics = tuple(f'ep/{m}' for m in self.metrics) + if self._wrap_level > 1: + return self._env.get_monitored_metrics() + metrics + else: + return self.metrics + + def step( + self, + key: chex.PRNGKey, + state: EnvState, + action: Union[int, float], + reset_state: Optional[chex.ArrayTree] = None, + extra: dict = None, + ) -> Tuple[chex.Array, EnvState, float, bool]: + step_kwargs = dict( + reset_state=reset_state + ) + if self._wrap_level > 1: + step_kwargs.update(dict( + extra=extra + )) + + step = self._env.step( + key, + state, + action, + **step_kwargs) + + if len(step) == 5: + obs, state, reward, done, info = step + else: + obs, state, reward, done, info, extra = step + + if len(self.metrics) > 0: + for m in self.metrics: + info[f'ep/{m}'] = info[m] + del info[m] + + return obs, state, reward, done, info, extra diff --git a/src/minimax/envs/wrappers/monitor_return.py b/src/minimax/envs/wrappers/monitor_return.py new file mode 100644 index 0000000..206eb62 --- /dev/null +++ b/src/minimax/envs/wrappers/monitor_return.py @@ -0,0 +1,76 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import chex +from typing import Tuple, Union, Optional + +from .env_wrapper import EnvWrapper +from minimax.envs.environment import EnvState + + +class MonitorReturnWrapper(EnvWrapper): + """ + Tracks episodic returns and, optionally, environment metrics. + """ + + def reset_extra(self): + if self._wrap_level > 1: + extra = self._env.reset_extra() + else: + extra = {} + + extra.update({ + 'ep_return': 0., + }) + + return extra + + def get_monitored_metrics(self): + return super().get_monitored_metrics() + ('return',) + + def step( + self, + key: chex.PRNGKey, + state: EnvState, + action: Union[int, float], + reset_state: Optional[chex.ArrayTree] = None, + extra: dict = None, + ) -> Tuple[chex.Array, EnvState, float, bool]: + step_kwargs = dict( + reset_state=reset_state + ) + if self._wrap_level > 1: + step_kwargs.update(dict( + extra=extra + )) + + step = self._env.step( + key, + state, + action, + **step_kwargs) + + if len(step) == 5: + obs, state, reward, done, info = step + else: + obs, state, reward, done, info, extra = step + + if type(reward) == dict: + reward = reward['agent_0'] # NOTE: Fully Cooperative taks + + if type(done) == dict: + done = done['__all__'] + # Track returns + extra['ep_return'] += reward + info['return'] = done*extra['ep_return'] + extra['ep_return'] *= (1-done) + + return obs, state, reward, done, info, extra diff --git a/src/minimax/envs/wrappers/ued_env_wrapper.py b/src/minimax/envs/wrappers/ued_env_wrapper.py new file mode 100644 index 0000000..9f96924 --- /dev/null +++ b/src/minimax/envs/wrappers/ued_env_wrapper.py @@ -0,0 +1,85 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import chex +from typing import Tuple, Union, Optional + +from minimax.envs.environment import Environment, EnvState + + +class UEDEnvWrapper: + """ + Abstract base class for an env wrapper. + """ + def __init__(self, env): + self._env = env + + self._wrap_level = 1 + while hasattr(env, '_env'): + if not isinstance(env, Environment): + self._wrap_level += 1 + + env = env._env + + @classmethod + def is_compatible(cls, env): + return True + + @property + def base_env(self): + env = self + for i in range(self._wrap_level): + env = env._env + + return env + + def reset_extra(self): + return {} + + def get_monitored_metrics(self): + if self._wrap_level > 1: + return self._env.get_monitored_metrics() + return () + + def _append_extra_to_tuple(self, _tuple, extra=None): + if extra is None: + extra = self.reset_extra() + + if self._wrap_level > 1: + _tuple[-1].update(extra) + else: + _tuple = _tuple + (extra,) + + return _tuple + + def reset_teacher( + self, + rng: chex.PRNGKey + ) -> Tuple[chex.ArrayTree, EnvState]: + _tuple = self._env.reset_teacher(rng) + + return self._append_extra_to_tuple(_tuple) + + def step_teacher( + self, + rng: chex.PRNGKey, + ued_state: EnvState, + action: Union[int, float], + extra: dict = None, + ) -> Tuple[chex.ArrayTree, EnvState, float, bool, dict]: + if self._wrap_level > 1: + return self._env.step_teacher(rng, ued_state, action, extra) + else: + _tuple = self._env.step_teacher(rng, ued_state, action) + return self._append_extra_to_tuple(_tuple) + + def __getattr__(self, attr): + return getattr(self._env, attr) diff --git a/src/minimax/envs/wrappers/world_state_wrapper.py b/src/minimax/envs/wrappers/world_state_wrapper.py new file mode 100644 index 0000000..06ea805 --- /dev/null +++ b/src/minimax/envs/wrappers/world_state_wrapper.py @@ -0,0 +1,160 @@ +from functools import partial +import jax +import jax.numpy as jnp + +import chex +from typing import Union, Optional + +from minimax.envs.environment import EnvState + +from minimax.envs import environment +from minimax.envs.wrappers.env_wrapper import EnvWrapper + + +class JaxMARLWrapper(object): + """Base class for all jaxmarl wrappers. + Copied from the JaxMARL project: https://github.com/FLAIROx/JaxMARL + """ + + def __init__(self, env: environment.Environment): + self._env = env + + def __getattr__(self, name: str): + return getattr(self._env, name) + + def _batchify_floats(self, x: dict): + return jnp.stack([x[a] for a in self._env.agents]) + + +class WorldStateWrapper(EnvWrapper): + + def __init__(self, env): + self._env = env + + self._wrap_level = 1 + while hasattr(env, '_env'): + if isinstance(env, EnvWrapper): + self._wrap_level += 1 + + env = env._env + + def __getattr__(self, name: str): + return getattr(self._env, name) + + def _batchify_floats(self, x: dict): + return jnp.stack([x[a] for a in self._env.agents]) + + @partial(jax.jit, static_argnums=0) + def world_state(self, obs): + """ + For each agent: [agent obs, all other agent obs] + + NOTE: This assumes two agents! + """ + # This is consistent with the OvercookedEnv implementation. + world_state_0 = jnp.concatenate( + [obs['agent_0'], obs['agent_1']], axis=-1) + world_state_1 = jnp.concatenate( + [obs['agent_1'], obs['agent_0']], axis=-1) + + return { + 'agent_0': world_state_0, + 'agent_1': world_state_1 + } + + @partial(jax.jit, static_argnums=0) + def reset(self, key): + res = self._env.reset(key) + obs = res[0] + world_state = self.world_state(obs) + obs["world_state"] = world_state + _tuple = (obs, *res[1:]) + return self._append_extra_to_tuple(_tuple) + + @partial(jax.jit, static_argnums=0) + def step(self, + key: chex.PRNGKey, + state: EnvState, + action: Union[int, float], + reset_state: Optional[chex.ArrayTree] = None, + extra: dict = None, + **kwargs): + if self._wrap_level > 1: + obs, env_state, reward, done, info = self._env.step( + key, state, action, **kwargs + ) + world_state = self.world_state(obs) + obs["world_state"] = world_state + return obs, env_state, reward, done, info + else: + obs, env_state, reward, done, info = self._env.step( + key, state, action, **kwargs + ) + world_state = self.world_state(obs) + obs["world_state"] = world_state + _tuple = (obs, env_state, reward, done, info) + return self._append_extra_to_tuple(_tuple, extra) + + @partial(jax.jit, static_argnums=0) + def set_state(self, state): + if self._wrap_level > 1: + obs, state = self._env.set_state(state) + world_state = self.world_state(obs) + obs["world_state"] = world_state + return obs, state + else: + obs, state = self._env.set_state(state) + world_state = self.world_state(obs) + obs["world_state"] = world_state + _tuple = (obs, state) + return self._append_extra_to_tuple(_tuple) + + @partial(jax.jit, static_argnums=0) + def reset_student( + self, + key, + state + ): + res = self._env.reset_student(key, state) + obs = res[0] + world_state = self.world_state(obs) + obs["world_state"] = world_state + return obs, *res[1:] + + def world_state_size(self): + spaces = [ + jnp.zeros(self._env.observation_space().shape) for _ in self._env.agents] + y = jnp.concatenate(spaces, axis=-1).shape + return y + + def reset_extra(self): + if self._wrap_level > 1: + extra = self._env.reset_extra() + else: + extra = {} + return extra + + def reset_teacher( + self, + rng + ): + _tuple = self._env.reset_teacher(rng) + + return self._append_extra_to_tuple(_tuple) + + def step_teacher( + self, + rng, + ued_state, + action, + extra: dict = None, + ): + if self._wrap_level > 1: + return self._env.step_teacher(rng, ued_state, action, extra) + else: + _tuple = self._env.step_teacher(rng, ued_state, action) + return self._append_extra_to_tuple(_tuple) + + @classmethod + def is_compatible(cls, env): + return env.name == "Overcooked" diff --git a/src/minimax/evaluate.py b/src/minimax/evaluate.py new file mode 100644 index 0000000..2de0b33 --- /dev/null +++ b/src/minimax/evaluate.py @@ -0,0 +1,244 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import json +import re +import fnmatch +import sys +from collections import defaultdict + +import numpy as np +import pandas as pd +import scipy.stats as spstats +import jax +import jax.numpy as jnp +from tqdm import tqdm + +from minimax.util.parsnip import Parsnip +from minimax.util.checkpoint import load_pkl_object, load_config +from minimax.util.loggers import HumanOutputFormat +from minimax.util.rl import AgentPop +import minimax.models as models +import minimax.agents as agents + + +parser = Parsnip() + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Random seed.') +parser.add_argument( + '--log_dir', + type=str, + default='~/logs/minimax', + help='Log directory containing experiment dirs.') +parser.add_argument( + '--xpid', + type=str, + default='latest', + help='Experiment ID dir name for model.') +parser.add_argument( + '--xpid_prefix', + type=str, + default=None, + help='Experiment ID dir name for model.') +parser.add_argument( + '--checkpoint_name', + type=str, + default='checkpoint', + help='Name of checkpoint .pkl.') +parser.add_argument( + '--env_names', + type=str, + help='csv of evaluation environments.') +parser.add_argument( + '--n_episodes', + type=int, + default=1, + help='Number of evaluation episodes.') +parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="Indices of agents to evaluate. '*' indicates all.") +parser.add_argument( + '--render_mode', + type=str, + nargs='?', + const=True, + default=None, + help='Visualize episodes.') +parser.add_argument( + '--results_path', + type=str, + default='results/', + help='Results dir.') +parser.add_argument( + '--results_fname', + type=str, + default=None, + help='Results filename (without .csv).') + + +if __name__ == '__main__': + """ + Usage: + python -m eval \ + --xpid= \ + --env_names="Maze-SixteenRooms" \ + --n_episodes=100 \ + --agent_idxs=0 + """ + args = parser.parse_args() + + log_dir_path = os.path.expandvars(os.path.expanduser(args.log_dir)) + + xpids = [] + if args.xpid_prefix is not None: + # Get all matching xpid directories + all_xpids = fnmatch.filter(os.listdir( + log_dir_path), f"{args.xpid_prefix}*") + filter_re = re.compile('.*_[0-9]*$') + xpids = [x for x in all_xpids if filter_re.match(x)] + else: + xpids = [args.xpid] + + pbar = tqdm(total=len(xpids)) + + all_eval_stats = defaultdict(list) + for xpid in xpids: + xpid_dir_path = os.path.join(log_dir_path, xpid) + checkpoint_path = os.path.join( + xpid_dir_path, f'{args.checkpoint_name}.pkl') + meta_path = os.path.join(xpid_dir_path, f'meta.json') + + # Load checkpoint info + if not os.path.exists(meta_path): + print(f'Configuration at {meta_path} does not exist. Skipping...') + continue + + if not os.path.exists(checkpoint_path): + print( + f'Checkpoint path {checkpoint_path} does not exist. Skipping...') + continue + + xp_args = load_config(meta_path) + + agent_idxs = args.agent_idxs + if agent_idxs == '*': + agent_idxs = np.arange(xp_args.train_runner_args.n_students) + else: + agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + assert np.max(agent_idxs) <= xp_args.train_runner_args.n_students, \ + 'Agent index is out of bounds.' + + runner_state = load_pkl_object(checkpoint_path) + if "params" in runner_state[1].keys(): + params = runner_state[1]['params'] + elif "actor_params" in runner_state[1].keys(): + params = runner_state[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params + ) + + with jax.disable_jit(args.render_mode is not None): + student_model = models.make( + env_name=xp_args.env_name, + model_name=xp_args.student_model_name, + **xp_args.student_model_args + ) + + # We force EvalRunner to select all params, since we've already + # extracted the relevant agent indices. + if "Overcooked" in args.env_names: + from minimax.runners_ma import EvalRunner + + pop = AgentPop( + agent=agents.MAPPOAgent(actor=student_model, critic=None), + n_agents=len(agent_idxs) + ) + elif "Maze" in args.env_names: + from minimax.runners import EvalRunner + + pop = AgentPop( + agent=agents.PPOAgent(model=student_model), + n_agents=len(agent_idxs) + ) + else: + raise ValueError("Unknown environment.") + + runner = EvalRunner( + pop=pop, + env_names=args.env_names, + env_kwargs=xp_args.eval_env_args, + n_episodes=args.n_episodes, + render_mode=args.render_mode, + agent_idxs='0' + ) + + rng = jax.random.PRNGKey(args.seed) + _eval_stats = runner.run(rng, params) + + eval_stats = {} + for k, v in _eval_stats.items(): + prefix_match = re.match(r'^eval/(a[0-9]+):.*', k) + if prefix_match is not None: + prefix = prefix_match.groups()[0] + _idx = int(prefix.lstrip('a').rstrip(':')) + idx = agent_idxs[_idx] + new_prefix = f'a{idx}' + new_k = k.replace(prefix, new_prefix) + eval_stats[new_k] = v + else: + eval_stats[k] = v + + for k, v in eval_stats.items(): + all_eval_stats[k].append(float(v)) + + pbar.update(1) + + pbar.close() + + aggregate_eval_stats = {} + for k, v in all_eval_stats.items(): + mean = np.mean(all_eval_stats[k]) + if len(all_eval_stats[k]) > 1: + sem = spstats.sem(all_eval_stats[k]) + else: + sem = 0.0 + aggregate_eval_stats[k] = f'{mean: 0.4}+/-{sem: 0.4}' + + _min = np.min(all_eval_stats[k]) + aggregate_eval_stats[f'min:{k}'] = f'{_min: 0.4}' + + logger = HumanOutputFormat(sys.stdout) + logger.writekvs(aggregate_eval_stats) + + if args.results_fname is not None: + if args.results_fname.strip('"') == '*': + results_fname = args.xpid_prefix or args.xpid + else: + results_fname = args.results_fname + + df = pd.DataFrame.from_dict(all_eval_stats) + results_path = args.results_path + if not os.path.isabs(results_path): + results_path = os.path.join( + os.path.abspath(__file__), results_path) + results_path = os.path.join(results_path, f'{results_fname}.csv') + df.to_csv(results_path, index=False) + print(f'Saved results to {results_path}') diff --git a/src/minimax/evaluate_against_baseline.py b/src/minimax/evaluate_against_baseline.py new file mode 100644 index 0000000..a79573c --- /dev/null +++ b/src/minimax/evaluate_against_baseline.py @@ -0,0 +1,276 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import json +import re +import fnmatch +import sys +from collections import defaultdict + +import numpy as np +import pandas as pd +import scipy.stats as spstats +import jax +import jax.numpy as jnp +from tqdm import tqdm + +from minimax.util.parsnip import Parsnip +from minimax.util.checkpoint import load_pkl_object, load_config +from minimax.util.loggers import HumanOutputFormat +from minimax.util.rl import AgentPop +import minimax.models as models +import minimax.agents as agents + + +class FixedModel(nn.Module): + """Useful as a model that acts randomly or always takes a certain action. + We use it to establish a baseline for cooperation. + """ + is_random: bool = False + always_pick_action: int = None + num_actions: int = 6 + + def setup(self): + super().__init__() + + def __call__(self, x, carry=None, reset=None): + if self.is_random: + # Same logits for all actions + logits = jnp.ones((x.shape[0], self.num_actions)) + return logits, carry + # Logits one for one action always + logits = jnp.zeros((x.shape[0], self.num_actions)) + return logits.at[:, self.always_pick_action].set(jnp.inf), carry + + def initialize_carry( + self, + rng: chex.PRNGKey, + batch_dims: Tuple[int] = ()) -> Tuple[chex.ArrayTree, chex.ArrayTree]: + """Initialize hidden state of LSTM.""" + return None + + @property + def is_recurrent(self): + return False + + +parser = Parsnip() + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Random seed.') +parser.add_argument( + '--log_dir', + type=str, + default='~/logs/minimax', + help='Log directory containing experiment dirs.') +parser.add_argument( + '--xpid', + type=str, + default='latest', + help='Experiment ID dir name for model.') +parser.add_argument( + '--xpid_prefix', + type=str, + default=None, + help='Experiment ID dir name for model.') +parser.add_argument( + '--checkpoint_name', + type=str, + default='checkpoint', + help='Name of checkpoint .pkl.') +parser.add_argument( + '--env_names', + type=str, + help='csv of evaluation environments.') +parser.add_argument( + '--n_episodes', + type=int, + default=1, + help='Number of evaluation episodes.') +parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="Indices of agents to evaluate. '*' indicates all.") +parser.add_argument( + '--render_mode', + type=str, + nargs='?', + const=True, + default=None, + help='Visualize episodes.') +parser.add_argument( + '--results_path', + type=str, + default='results/', + help='Results dir.') +parser.add_argument( + '--results_fname', + type=str, + default=None, + help='Results filename (without .csv).') + + +if __name__ == '__main__': + """ + Usage: + python -m eval \ + --xpid= \ + --env_names="Maze-SixteenRooms" \ + --n_episodes=100 \ + --agent_idxs=0 + """ + args = parser.parse_args() + + log_dir_path = os.path.expandvars(os.path.expanduser(args.log_dir)) + + xpids = [] + if args.xpid_prefix is not None: + # Get all matching xpid directories + all_xpids = fnmatch.filter(os.listdir( + log_dir_path), f"{args.xpid_prefix}*") + filter_re = re.compile('.*_[0-9]*$') + xpids = [x for x in all_xpids if filter_re.match(x)] + else: + xpids = [args.xpid] + + pbar = tqdm(total=len(xpids)) + + all_eval_stats = defaultdict(list) + for xpid in xpids: + xpid_dir_path = os.path.join(log_dir_path, xpid) + checkpoint_path = os.path.join( + xpid_dir_path, f'{args.checkpoint_name}.pkl') + meta_path = os.path.join(xpid_dir_path, f'meta.json') + + # Load checkpoint info + if not os.path.exists(meta_path): + print(f'Configuration at {meta_path} does not exist. Skipping...') + continue + + if not os.path.exists(checkpoint_path): + print( + f'Checkpoint path {checkpoint_path} does not exist. Skipping...') + continue + + xp_args = load_config(meta_path) + + agent_idxs = args.agent_idxs + if agent_idxs == '*': + agent_idxs = np.arange(xp_args.train_runner_args.n_students) + else: + agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + assert np.max(agent_idxs) <= xp_args.train_runner_args.n_students, \ + 'Agent index is out of bounds.' + + runner_state = load_pkl_object(checkpoint_path) + if "params" in runner_state[1].keys(): + params = runner_state[1]['params'] + elif "actor_params" in runner_state[1].keys(): + params = runner_state[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params + ) + + with jax.disable_jit(args.render_mode is not None): + student_model = models.make( + env_name=xp_args.env_name, + model_name=xp_args.student_model_name, + **xp_args.student_model_args + ) + + # We force EvalRunner to select all params, since we've already + # extracted the relevant agent indices. + if "Overcooked" in args.env_names: + from minimax.runners_ma import EvalRunner + + pop = AgentPop( + agent=agents.MAPPOAgent(actor=student_model, critic=None), + n_agents=len(agent_idxs) + ) + elif "Maze" in args.env_names: + from minimax.runners import EvalRunner + + pop = AgentPop( + agent=agents.PPOAgent(model=student_model), + n_agents=len(agent_idxs) + ) + else: + raise ValueError("Unknown environment.") + + runner = EvalRunner( + pop=pop, + env_names=args.env_names, + env_kwargs=xp_args.eval_env_args, + n_episodes=args.n_episodes, + render_mode=args.render_mode, + agent_idxs='*' + ) + + rng = jax.random.PRNGKey(args.seed) + _eval_stats = runner.run(rng, params) + + eval_stats = {} + for k, v in _eval_stats.items(): + prefix_match = re.match(r'^eval/(a[0-9]+):.*', k) + if prefix_match is not None: + prefix = prefix_match.groups()[0] + _idx = int(prefix.lstrip('a').rstrip(':')) + idx = agent_idxs[_idx] + new_prefix = f'a{idx}' + new_k = k.replace(prefix, new_prefix) + eval_stats[new_k] = v + else: + eval_stats[k] = v + + for k, v in eval_stats.items(): + all_eval_stats[k].append(float(v)) + + pbar.update(1) + + pbar.close() + + aggregate_eval_stats = {} + for k, v in all_eval_stats.items(): + mean = np.mean(all_eval_stats[k]) + if len(all_eval_stats[k]) > 1: + sem = spstats.sem(all_eval_stats[k]) + else: + sem = 0.0 + aggregate_eval_stats[k] = f'{mean: 0.4}+/-{sem: 0.4}' + + _min = np.min(all_eval_stats[k]) + aggregate_eval_stats[f'min:{k}'] = f'{_min: 0.4}' + + logger = HumanOutputFormat(sys.stdout) + logger.writekvs(aggregate_eval_stats) + + if args.results_fname is not None: + if args.results_fname.strip('"') == '*': + results_fname = args.xpid_prefix or args.xpid + else: + results_fname = args.results_fname + + df = pd.DataFrame.from_dict(all_eval_stats) + results_path = args.results_path + if not os.path.isabs(results_path): + results_path = os.path.join( + os.path.abspath(__file__), results_path) + results_path = os.path.join(results_path, f'{results_fname}.csv') + df.to_csv(results_path, index=False) + print(f'Saved results to {results_path}') diff --git a/src/minimax/evaluate_against_population.py b/src/minimax/evaluate_against_population.py new file mode 100644 index 0000000..649323f --- /dev/null +++ b/src/minimax/evaluate_against_population.py @@ -0,0 +1,281 @@ +import os +import json +import re +import fnmatch +import sys +from collections import defaultdict + +import numpy as np +import pandas as pd +import scipy.stats as spstats +import jax +import jax.numpy as jnp +from tqdm import tqdm + +from minimax.util.parsnip import Parsnip +from minimax.util.checkpoint import load_pkl_object, load_config +from minimax.util.loggers import HumanOutputFormat +from minimax.util.rl import AgentPopHeterogenous +import minimax.models as models +import minimax.agents as agents + + +parser = Parsnip() + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Random seed.') +parser.add_argument( + '--population_json', + type=str, + default=None, + help='Path to population json file.') +parser.add_argument( + '--log_dir', + type=str, + default='~/logs/minimax', + help='Log directory containing experiment dirs.') +parser.add_argument( + '--xpid', + type=str, + default='latest', + help='Experiment ID dir name for model.') +parser.add_argument( + '--xpid_prefix', + type=str, + default=None, + help='Experiment ID dir name for model.') +parser.add_argument( + '--checkpoint_name', + type=str, + default='checkpoint', + help='Name of checkpoint .pkl.') +parser.add_argument( + '--env_names', + type=str, + help='csv of evaluation environments.') +parser.add_argument( + '--n_episodes', + type=int, + default=1, + help='Number of evaluation episodes.') +parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="Indices of agents to evaluate. '*' indicates all.") +parser.add_argument( + '--render_mode', + type=str, + nargs='?', + const=True, + default=None, + help='Visualize episodes.') +parser.add_argument( + '--results_path', + type=str, + default='results/', + help='Results dir.') +parser.add_argument( + '--results_fname', + type=str, + default=None, + help='Results filename (without .csv).') + + +if __name__ == '__main__': + """ + Usage: + python -m eval \ + --xpid= \ + --env_names="Maze-SixteenRooms" \ + --n_episodes=100 \ + --agent_idxs=0 + """ + args = parser.parse_args() + + log_dir_path = os.path.expandvars(os.path.expanduser(args.log_dir)) + + xpid = args.xpid + + population_json_path = args.population_json + + with open(population_json_path, 'r') as f: + population = json.load(f) + + population_size = int(population["population_size"]) + + pbar = tqdm(total=population_size*2) + + all_eval_stats = defaultdict(list) + for agent_id in range(1, population_size+1): + xpid_dir_path = os.path.join(log_dir_path, xpid) + checkpoint_path = os.path.join( + xpid_dir_path, f'{args.checkpoint_name}.pkl') + meta_path = os.path.join(xpid_dir_path, f'meta.json') + + other_agent_checkpoint_path = f"{os.getcwd()}/{population[str(agent_id)]}" + other_agent_meta_path = f"{os.getcwd()}/{population[f'{agent_id}_meta']}" + + # Load checkpoint info + if not os.path.exists(meta_path): + print(f'Configuration at {meta_path} does not exist. Skipping...') + continue + + if not os.path.exists(other_agent_meta_path): + raise ValueError(f"Did not find: {other_agent_meta_path}") + + if not os.path.exists(checkpoint_path): + print( + f'Checkpoint path {checkpoint_path} does not exist. Skipping...') + continue + + if not os.path.exists(other_agent_checkpoint_path): + raise ValueError(f"Did not find: {other_agent_checkpoint_path}") + + xp_args = load_config(meta_path) + + xp_population_args = load_config(other_agent_meta_path) + + agent_idxs = args.agent_idxs + if agent_idxs == '*': + agent_idxs = np.arange(xp_args.train_runner_args.n_students) + else: + agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + assert np.max(agent_idxs) <= xp_args.train_runner_args.n_students, \ + 'Agent index is out of bounds.' + + runner_state_0 = load_pkl_object(checkpoint_path) + if "params" in runner_state_0[1].keys(): + params_0 = runner_state_0[1]['params'] + elif "actor_params" in runner_state_0[1].keys(): + params_0 = runner_state_0[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params_0 = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params_0 + ) + + xp_args_other = load_config(other_agent_meta_path) + + runner_state_1 = load_pkl_object(other_agent_checkpoint_path) + if "params" in runner_state_1[1].keys(): + params_1 = runner_state_1[1]['params'] + elif "actor_params" in runner_state_1[1].keys(): + params_1 = runner_state_1[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params_1 = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params_1 + ) + + for i in range(2): + # Swap params and runner states + # Bit finicky be careful here + if i == 1: + params_0, params_1 = params_1, params_0 + xp_args, xp_population_args = xp_population_args, xp_args + runner_state_0, runner_state_1 = runner_state_1, runner_state_0 + + with jax.disable_jit(args.render_mode is not None): + student_model = models.make( + env_name=xp_args.env_name, + model_name=xp_args.student_model_name, + **xp_args.student_model_args + ) + + population_model = models.make( + env_name=xp_args.env_name, + model_name=xp_population_args.student_model_name, + **xp_population_args.student_model_args + ) + + # We force EvalRunner to select all params, since we've already + # extracted the relevant agent indices. + if "Overcooked" in args.env_names: + from minimax.runners_ma import EvalRunnerHeterogenous + + pop = AgentPopHeterogenous( + agent_0=agents.MAPPOAgent( + actor=student_model, critic=None), + agent_1=agents.MAPPOAgent( + actor=population_model, critic=None), + n_agents=len(agent_idxs) + ) + else: + raise ValueError("Unknown environment.") + + runner = EvalRunnerHeterogenous( + pop=pop, + env_names=args.env_names, + env_kwargs=xp_args.eval_env_args, + n_episodes=args.n_episodes, + render_mode=args.render_mode, + agent_idxs='*' + ) + + rng = jax.random.PRNGKey(args.seed) + _eval_stats = runner.run(rng, params_0, params_1) + + eval_stats = {} + for k, v in _eval_stats.items(): + prefix_match = re.match(r'^eval/(a[0-9]+):.*', k) + if prefix_match is not None: + prefix = prefix_match.groups()[0] + _idx = int(prefix.lstrip('a').rstrip(':')) + idx = agent_idxs[_idx] + new_prefix = f'a{idx}' + new_k = k.replace(prefix, new_prefix) + eval_stats[new_k] = v + else: + eval_stats[k] = v + + for k, v in eval_stats.items(): + all_eval_stats[k].append(float(v)) + all_eval_stats[k+":"+population[str(agent_id)].split( + "/")[-1][:-4]].append(float(v)) + + pbar.update(1) + + pbar.close() + + aggregate_eval_stats = {} + for k, v in all_eval_stats.items(): + max = np.max(all_eval_stats[k]) + mean = np.mean(all_eval_stats[k]) + if ":test_return:" in k: + print(f"k {k}, v {v}") + if len(all_eval_stats[k]) > 1: + sem = spstats.sem(all_eval_stats[k]) + else: + sem = 0.0 + aggregate_eval_stats[k] = f'{mean: 0.4}+/-{sem: 0.4} (max: {max: 0.4})' + + _min = np.min(all_eval_stats[k]) + aggregate_eval_stats[f'min:{k}'] = f'{_min: 0.4}' + + logger = HumanOutputFormat(sys.stdout) + logger.writekvs(aggregate_eval_stats) + + if args.results_fname is not None: + if args.results_fname.strip('"') == '*': + results_fname = args.xpid_prefix or args.xpid + else: + results_fname = args.results_fname + + df = pd.DataFrame.from_dict(all_eval_stats) + results_path = args.results_path + if not os.path.isabs(results_path): + results_path = os.path.join( + os.path.abspath(__file__), results_path) + results_path = os.path.join(results_path, f'{results_fname}.csv') + df.to_csv(results_path, index=False) + print(f'Saved results to {results_path}') diff --git a/src/minimax/evaluate_baseline_against_population.py b/src/minimax/evaluate_baseline_against_population.py new file mode 100644 index 0000000..f26f7d5 --- /dev/null +++ b/src/minimax/evaluate_baseline_against_population.py @@ -0,0 +1,319 @@ +import os +import json +import re +import sys +from collections import defaultdict +from typing import Tuple + +import chex +import numpy as np +import pandas as pd +import scipy.stats as spstats +import jax +import jax.numpy as jnp +import flax.linen as nn +from tqdm import tqdm + +from minimax.util.parsnip import Parsnip +from minimax.util.checkpoint import load_pkl_object, load_config +from minimax.util.loggers import HumanOutputFormat +from minimax.util.rl import AgentPopHeterogenous +import minimax.models as models +import minimax.agents as agents + + +class FixedModel(nn.Module): + """Useful as a model that acts randomly or always takes a certain action. + We use it to establish a baseline for cooperation. + """ + is_random: bool = False + always_pick_action: int = None + num_actions: int = 6 + + def setup(self): + super().__init__() + + def __call__(self, x, carry=None, reset=None): + if self.is_random: + # Same logits for all actions + logits = jnp.ones((x.shape[0], self.num_actions)) + return logits, carry + # Logits one for one action always + logits = jnp.zeros((x.shape[0], self.num_actions)) + return logits.at[:, self.always_pick_action].set(jnp.inf), carry + + def initialize_carry( + self, + rng: chex.PRNGKey, + batch_dims: Tuple[int] = ()) -> Tuple[chex.ArrayTree, chex.ArrayTree]: + """Initialize hidden state of LSTM.""" + return None + + @property + def is_recurrent(self): + return False + + +parser = Parsnip() + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Random seed.') +parser.add_argument( + '--population_json', + type=str, + default=None, + help='Path to population json file.') +parser.add_argument( + '--log_dir', + type=str, + default='~/logs/minimax', + help='Log directory containing experiment dirs.') +parser.add_argument( + '--env_names', + type=str, + help='csv of evaluation environments.') +parser.add_argument( + '--n_episodes', + type=int, + default=1, + help='Number of evaluation episodes.') +parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="Indices of agents to evaluate. '*' indicates all.") +parser.add_argument( + '--render_mode', + type=str, + nargs='?', + const=True, + default=None, + help='Visualize episodes.') +parser.add_argument( + '--results_path', + type=str, + default='results/', + help='Results dir.') +parser.add_argument( + '--results_fname', + type=str, + default=None, + help='Results filename (without .csv).') +parser.add_argument( + '--is_random', + type=str, + nargs='?', + const=True, + default=None, + help='Random fixed agent.') + +if __name__ == '__main__': + """ + Usage: + python -m eval \ + --xpid= \ + --env_names="Maze-SixteenRooms" \ + --n_episodes=100 \ + --agent_idxs=0 + """ + args = parser.parse_args() + + log_dir_path = os.path.expandvars(os.path.expanduser(args.log_dir)) + + xpid = args.xpid + + population_json_path = args.population_json + + with open(population_json_path, 'r') as f: + population = json.load(f) + + population_size = int(population["population_size"]) + + pbar = tqdm(total=population_size*2) + + all_eval_stats = defaultdict(list) + for agent_id in range(1, population_size+1): + # xpid_dir_path = os.path.join(log_dir_path, xpid) + # checkpoint_path = os.path.join( + # xpid_dir_path, f'{args.checkpoint_name}.pkl') + # meta_path = os.path.join(xpid_dir_path, f'meta.json') + + other_agent_checkpoint_path = f"{os.getcwd()}/{population[str(agent_id)]}" + other_agent_meta_path = f"{os.getcwd()}/{population[f'{agent_id}_meta']}" + + # Load checkpoint info + # if not os.path.exists(meta_path): + # print(f'Configuration at {meta_path} does not exist. Skipping...') + # continue + + if not os.path.exists(other_agent_meta_path): + raise ValueError(f"Did not find: {other_agent_meta_path}") + + # if not os.path.exists(checkpoint_path): + # print( + # f'Checkpoint path {checkpoint_path} does not exist. Skipping...') + # continue + + if not os.path.exists(other_agent_checkpoint_path): + raise ValueError(f"Did not find: {other_agent_checkpoint_path}") + + # xp_args = load_config(meta_path) + + xp_population_args = load_config(other_agent_meta_path) + + agent_idxs = args.agent_idxs + if agent_idxs == '*': + agent_idxs = np.arange( + xp_population_args.train_runner_args.n_students) + else: + agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + assert np.max(agent_idxs) <= xp_population_args.train_runner_args.n_students, \ + 'Agent index is out of bounds.' + + # runner_state_0 = load_pkl_object(checkpoint_path) + # if "params" in runner_state_0[1].keys(): + # params_0 = runner_state_0[1]['params'] + # elif "actor_params" in runner_state_0[1].keys(): + # params_0 = runner_state_0[1]['actor_params'] + # else: + # raise ValueError("No params found in checkpoint.") + + # params_0 = jax.tree_util.tree_map( + # lambda x: jnp.take(x, indices=agent_idxs, axis=0), + # params_0 + # ) + + xp_args_other = load_config(other_agent_meta_path) + + runner_state_1 = load_pkl_object(other_agent_checkpoint_path) + if "params" in runner_state_1[1].keys(): + params_1 = runner_state_1[1]['params'] + elif "actor_params" in runner_state_1[1].keys(): + params_1 = runner_state_1[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params_1 = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params_1 + ) + + # We use models without parameters + # {'params': {'place_holder': jnp.zeros(1,)}} + params_0 = params_1.copy() + + for i in range(2): + # Swap params and runner states + # Bit finicky be careful here + + with jax.disable_jit(args.render_mode is not None): + if args.is_random: + student_model = FixedModel( + is_random=True, + num_actions=6 + ) + else: + student_model = FixedModel( + is_random=False, + always_pick_action=4, # Stay = 4 + num_actions=6 + ) + + population_model = models.make( + env_name=xp_population_args.env_name, + model_name=xp_population_args.student_model_name, + **xp_population_args.student_model_args + ) + + if i == 1: + params_0, params_1 = params_1, params_0 + student_model, population_model = population_model, student_model + # xp_args, xp_population_args = xp_population_args, xp_args + # runner_state_0, runner_state_1 = runner_state_1, runner_state_0 + + # We force EvalRunner to select all params, since we've already + # extracted the relevant agent indices. + if "Overcooked" in args.env_names: + from minimax.runners_ma import EvalRunnerHeterogenous + + pop = AgentPopHeterogenous( + agent_0=agents.MAPPOAgent( + actor=student_model, critic=None), + agent_1=agents.MAPPOAgent( + actor=population_model, critic=None), + n_agents=len(agent_idxs) + ) + else: + raise ValueError("Unknown environment.") + + runner = EvalRunnerHeterogenous( + pop=pop, + env_names=args.env_names, + env_kwargs=xp_population_args.eval_env_args, + n_episodes=args.n_episodes, + render_mode=args.render_mode, + agent_idxs='*' + ) + + rng = jax.random.PRNGKey(args.seed) + _eval_stats = runner.run(rng, params_0, params_1) + + eval_stats = {} + for k, v in _eval_stats.items(): + prefix_match = re.match(r'^eval/(a[0-9]+):.*', k) + if prefix_match is not None: + prefix = prefix_match.groups()[0] + _idx = int(prefix.lstrip('a').rstrip(':')) + idx = agent_idxs[_idx] + new_prefix = f'a{idx}' + new_k = k.replace(prefix, new_prefix) + eval_stats[new_k] = v + else: + eval_stats[k] = v + + for k, v in eval_stats.items(): + all_eval_stats[k].append(float(v)) + + pbar.update(1) + + pbar.close() + + aggregate_eval_stats = {} + for k, v in all_eval_stats.items(): + max = np.max(all_eval_stats[k]) + mean = np.mean(all_eval_stats[k]) + if ":test_return:" in k: + print(f"k {k}, v {v}") + if ":test_solved_rate:" in k: + print(f"k {k}, v {v}") + if len(all_eval_stats[k]) > 1: + sem = spstats.sem(all_eval_stats[k]) + else: + sem = 0.0 + aggregate_eval_stats[k] = f'{mean: 0.4}+/-{sem: 0.4} (max: {max: 0.4})' + + _min = np.min(all_eval_stats[k]) + aggregate_eval_stats[f'min:{k}'] = f'{_min: 0.4}' + + logger = HumanOutputFormat(sys.stdout) + logger.writekvs(aggregate_eval_stats) + + if args.results_fname is not None: + if args.results_fname.strip('"') == '*': + results_fname = args.xpid_prefix or args.xpid + else: + results_fname = args.results_fname + + df = pd.DataFrame.from_dict(all_eval_stats) + results_path = args.results_path + if not os.path.isabs(results_path): + results_path = os.path.join( + os.path.abspath(__file__), results_path) + results_path = os.path.join(results_path, f'{results_fname}.csv') + df.to_csv(results_path, index=False) + print(f'Saved results to {results_path}') diff --git a/src/minimax/evaluate_from_pckl.py b/src/minimax/evaluate_from_pckl.py new file mode 100644 index 0000000..6dd51a7 --- /dev/null +++ b/src/minimax/evaluate_from_pckl.py @@ -0,0 +1,264 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import re +import sys +from collections import defaultdict + +import numpy as np +import pandas as pd +import scipy.stats as spstats +import jax +import jax.numpy as jnp +from tqdm import tqdm + +from minimax.util.parsnip import Parsnip +from minimax.util.checkpoint import load_pkl_object, load_config +from minimax.util.loggers import HumanOutputFormat +from minimax.util.rl import AgentPopHeterogenous +import minimax.models as models +import minimax.agents as agents + + +parser = Parsnip() + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Random seed.') +parser.add_argument( + '--pckl_path', + type=str, + default=None, + help='Path to population json file.') +parser.add_argument( + '--meta_path', + type=str, + default=None, + help='Path to population json file.') +parser.add_argument( + '--log_dir', + type=str, + default='~/logs/minimax', + help='Log directory containing experiment dirs.') +parser.add_argument( + '--other_pckl_path', + type=str, + default=None, + help='Path to population json file.') +parser.add_argument( + '--other_meta_path', + type=str, + default=None, + help='Path to population json file.') +parser.add_argument( + '--env_names', + type=str, + help='csv of evaluation environments.') +parser.add_argument( + '--n_episodes', + type=int, + default=1, + help='Number of evaluation episodes.') +parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="Indices of agents to evaluate. '*' indicates all.") +parser.add_argument( + '--render_mode', + type=str, + nargs='?', + const=True, + default=None, + help='Visualize episodes.') +parser.add_argument( + '--results_path', + type=str, + default='results/', + help='Results dir.') +parser.add_argument( + '--results_fname', + type=str, + default=None, + help='Results filename (without .csv).') + + +if __name__ == '__main__': + """ + Usage: + python -m eval \ + --xpid= \ + --env_names="Maze-SixteenRooms" \ + --n_episodes=100 \ + --agent_idxs=0 + """ + args = parser.parse_args() + + # log_dir_path = os.path.expandvars(os.path.expanduser(args.log_dir)) + + all_eval_stats = defaultdict(list) + # xpid_dir_path = os.path.join(log_dir_path, xpid) + checkpoint_path = args.pckl_path + meta_path = args.meta_path + + other_agent_checkpoint_path = args.other_pckl_path + other_agent_meta_path = args.other_meta_path + + # Load checkpoint info + if not os.path.exists(meta_path): + print(f'Configuration at {meta_path} does not exist. Skipping...') + + if not os.path.exists(other_agent_meta_path): + raise ValueError(f"Did not find: {other_agent_meta_path}") + + if not os.path.exists(checkpoint_path): + print( + f'Checkpoint path {checkpoint_path} does not exist. Skipping...') + + if not os.path.exists(other_agent_checkpoint_path): + raise ValueError(f"Did not find: {other_agent_checkpoint_path}") + + xp_args = load_config(meta_path) + + xp_other_args = load_config(other_agent_meta_path) + + agent_idxs = args.agent_idxs + if agent_idxs == '*': + agent_idxs = np.arange(xp_args.train_runner_args.n_students) + else: + agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + assert np.max(agent_idxs) <= xp_args.train_runner_args.n_students, \ + 'Agent index is out of bounds.' + + runner_state_0 = load_pkl_object(checkpoint_path) + if "params" in runner_state_0[1].keys(): + params_0 = runner_state_0[1]['params'] + elif "actor_params" in runner_state_0[1].keys(): + params_0 = runner_state_0[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params_0 = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params_0 + ) + + xp_args_other = load_config(other_agent_meta_path) + + runner_state_1 = load_pkl_object(other_agent_checkpoint_path) + if "params" in runner_state_1[1].keys(): + params_1 = runner_state_1[1]['params'] + elif "actor_params" in runner_state_1[1].keys(): + params_1 = runner_state_1[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params_1 = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params_1 + ) + + for i in range(2): + # Swap params and runner states + # Bit finicky be careful here + if i == 1: + params_0, params_1 = params_1, params_0 + xp_args, xp_other_args = xp_other_args, xp_args + runner_state_0, runner_state_1 = runner_state_1, runner_state_0 + + with jax.disable_jit(args.render_mode is not None): + student_model = models.make( + env_name=xp_args.env_name, + model_name=xp_args.student_model_name, + **xp_args.student_model_args + ) + + population_model = models.make( + env_name=xp_args.env_name, + model_name=xp_other_args.student_model_name, + **xp_other_args.student_model_args + ) + + # We force EvalRunner to select all params, since we've already + # extracted the relevant agent indices. + if "Overcooked" in args.env_names: + from minimax.runners_ma import EvalRunnerHeterogenous + + pop = AgentPopHeterogenous( + agent_0=agents.MAPPOAgent( + actor=student_model, critic=None), + agent_1=agents.MAPPOAgent( + actor=population_model, critic=None), + n_agents=len(agent_idxs) + ) + else: + raise ValueError("Unknown environment.") + + runner = EvalRunnerHeterogenous( + pop=pop, + env_names=args.env_names, + env_kwargs=xp_args.eval_env_args, + n_episodes=args.n_episodes, + render_mode=args.render_mode, + agent_idxs='*' + ) + + rng = jax.random.PRNGKey(args.seed) + _eval_stats = runner.run(rng, params_0, params_1) + + eval_stats = {} + for k, v in _eval_stats.items(): + prefix_match = re.match(r'^eval/(a[0-9]+):.*', k) + if prefix_match is not None: + prefix = prefix_match.groups()[0] + _idx = int(prefix.lstrip('a').rstrip(':')) + idx = agent_idxs[_idx] + new_prefix = f'a{idx}' + new_k = k.replace(prefix, new_prefix) + eval_stats[new_k] = v + else: + eval_stats[k] = v + + for k, v in eval_stats.items(): + all_eval_stats[k].append(float(v)) + + aggregate_eval_stats = {} + for k, v in all_eval_stats.items(): + max = np.max(all_eval_stats[k]) + mean = np.mean(all_eval_stats[k]) + if len(all_eval_stats[k]) > 1: + sem = spstats.sem(all_eval_stats[k]) + else: + sem = 0.0 + aggregate_eval_stats[k] = f'{mean: 0.4}+/-{sem: 0.4} (max: {max: 0.4})' + + _min = np.min(all_eval_stats[k]) + aggregate_eval_stats[f'min:{k}'] = f'{_min: 0.4}' + + logger = HumanOutputFormat(sys.stdout) + logger.writekvs(aggregate_eval_stats) + + if args.results_fname is not None: + if args.results_fname.strip('"') == '*': + results_fname = args.xpid_prefix or args.xpid + else: + results_fname = args.results_fname + + df = pd.DataFrame.from_dict(all_eval_stats) + results_path = args.results_path + if not os.path.isabs(results_path): + results_path = os.path.join( + os.path.abspath(__file__), results_path) + results_path = os.path.join(results_path, f'{results_fname}.csv') + df.to_csv(results_path, index=False) + print(f'Saved results to {results_path}') diff --git a/src/minimax/extract_fcp.py b/src/minimax/extract_fcp.py new file mode 100644 index 0000000..0efad9c --- /dev/null +++ b/src/minimax/extract_fcp.py @@ -0,0 +1,276 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import json +import glob +import re +import fnmatch +import sys +from collections import defaultdict + +import numpy as np +import pandas as pd +import scipy.stats as spstats +import jax +import jax.numpy as jnp +from tqdm import tqdm + +from minimax.util.parsnip import Parsnip +from minimax.util.checkpoint import load_pkl_object, load_config +from minimax.util.loggers import HumanOutputFormat +from minimax.util.rl import AgentPop +import minimax.models as models +import minimax.agents as agents + + +parser = Parsnip() + +# ==== Define top-level arguments +parser.add_argument( + '--seed', + type=int, + default=1, + help='Random seed.') +parser.add_argument( + '--log_dir', + type=str, + default='~/logs/minimax', + help='Log directory containing experiment dirs.') +parser.add_argument( + '--xpid', + type=str, + default='latest', + help='Experiment ID dir name for model.') +parser.add_argument( + '--xpid_prefix', + type=str, + default=None, + help='Experiment ID dir name for model.') +parser.add_argument( + '--checkpoint_name', + type=str, + default='checkpoint', + help='Name of checkpoint .pkl.') +parser.add_argument( + '--env_names', + type=str, + help='csv of evaluation environments.') +parser.add_argument( + '--n_episodes', + type=int, + default=1, + help='Number of evaluation episodes.') +parser.add_argument( + '--agent_idxs', + type=str, + default='*', + help="Indices of agents to evaluate. '*' indicates all.") +parser.add_argument( + '--render_mode', + type=str, + nargs='?', + const=True, + default=None, + help='Visualize episodes.') +parser.add_argument( + '--results_path', + type=str, + default='results/', + help='Results dir.') +parser.add_argument( + '--results_fname', + type=str, + default=None, + help='Results filename (without .csv).') + +parser.add_argument( + '--trained_seed', + type=int, + default=None, + help='Seed that the model was trained with') + +if __name__ == '__main__': + args = parser.parse_args() + + log_dir_path = os.path.expandvars(os.path.expanduser(args.log_dir)) + + xpids = [] + if args.xpid_prefix is not None: + # Get all matching xpid directories + all_xpids = fnmatch.filter(os.listdir( + log_dir_path), f"{args.xpid_prefix}*") + filter_re = re.compile('.*_[0-9]*$') + xpids = [x for x in all_xpids if filter_re.match(x)] + else: + xpids = [args.xpid] + + pbar = tqdm(total=len(xpids)) + + all_eval_stats = defaultdict(list) + for xpid in xpids: + xpid_dir_path = os.path.join(log_dir_path, xpid) + checkpoint_path = os.path.join( + xpid_dir_path, f'{args.checkpoint_name}.pkl') + meta_path = os.path.join(xpid_dir_path, f'meta.json') + + # Load checkpoint info + if not os.path.exists(meta_path): + print(f'Configuration at {meta_path} does not exist. Skipping...') + continue + + if not os.path.exists(checkpoint_path): + print( + f'Checkpoint path {checkpoint_path} does not exist. Skipping...') + continue + + xp_args = load_config(meta_path) + + agent_idxs = args.agent_idxs + if agent_idxs == '*': + agent_idxs = np.arange(xp_args.train_runner_args.n_students) + else: + agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + assert np.max(agent_idxs) <= xp_args.train_runner_args.n_students, \ + 'Agent index is out of bounds.' + + sub_checkpoint_paths = glob.glob(f"{checkpoint_path[:-4]}*.pkl") + sub_checkpoint_paths = sorted(list(sub_checkpoint_paths)) + + map_name_path = {} + map_name_params = {} + for sub_checkpoint_path in sub_checkpoint_paths: + desc = sub_checkpoint_path[len(checkpoint_path[:-4])+1:-4] + if desc == '': + desc = 'final' + runner_state = load_pkl_object(sub_checkpoint_path) + if "params" in runner_state[1].keys(): + params = runner_state[1]['params'] + elif "actor_params" in runner_state[1].keys(): + params = runner_state[1]['actor_params'] + else: + raise ValueError("No params found in checkpoint.") + + params = jax.tree_util.tree_map( + lambda x: jnp.take(x, indices=agent_idxs, axis=0), + params + ) + map_name_path[desc] = sub_checkpoint_path + map_name_params[desc] = params + + map_name_eval_stas = {} + for desc, params in map_name_params.items(): + with jax.disable_jit(args.render_mode is not None): + student_model = models.make( + env_name=xp_args.env_name, + model_name=xp_args.student_model_name, + **xp_args.student_model_args + ) + + # We force EvalRunner to select all params, since we've already + # extracted the relevant agent indices. + if "Overcooked" in args.env_names: + from minimax.runners_ma import EvalRunner + + pop = AgentPop( + agent=agents.MAPPOAgent( + actor=student_model, critic=None), + n_agents=len(agent_idxs) + ) + elif "Maze" in args.env_names: + from minimax.runners import EvalRunner + + pop = AgentPop( + agent=agents.PPOAgent(model=student_model), + n_agents=len(agent_idxs) + ) + else: + raise ValueError("Unknown environment.") + + runner = EvalRunner( + pop=pop, + env_names=args.env_names, + env_kwargs=xp_args.eval_env_args, + n_episodes=args.n_episodes, + render_mode=args.render_mode, + agent_idxs='*' + ) + + rng = jax.random.PRNGKey(args.seed) + _eval_stats = runner.run(rng, params) + + eval_stats = {} + for k, v in _eval_stats.items(): + prefix_match = re.match(r'^eval/(a[0-9]+):.*', k) + if prefix_match is not None: + prefix = prefix_match.groups()[0] + _idx = int(prefix.lstrip('a').rstrip(':')) + idx = agent_idxs[_idx] + new_prefix = f'a{idx}' + new_k = k.replace(prefix, new_prefix) + eval_stats[new_k] = v + else: + eval_stats[k] = v + + for k, v in eval_stats.items(): + all_eval_stats[k].append(float(v)) + + pbar.update(1) + + assert len( + runner.ext_env_names) == 1, "Only one at a time to avoid confusion!" + + map_name_eval_stas[desc] = eval_stats[ + f"eval/a0:test_return:{runner.ext_env_names[0]}"] + + best = max(map_name_eval_stas.items(), key=lambda x: x[1]) + mid = min(map_name_eval_stas.items(), + key=lambda x: abs(x[1]-int(best[1])/2)) + low = min(map_name_eval_stas.items(), + key=lambda x: abs(x[1]-int(best[1])/5)) + + print("\n\n------------------------------------\n\n") + print("Best: ", best) + print("Mid: ", mid) + print("Low: ", low) + + high_id = best[0] + mid_id = mid[0] + low_id = low[0] + + # Take paths from ids and cp files to /populations/fcp/seed/ + high_path = map_name_path[high_id] + mid_path = map_name_path[mid_id] + low_path = map_name_path[low_id] + + print("High path: ", high_path) + print("Mid path: ", mid_path) + print("Low path: ", low_path) + + # find seed string with SEED_*_ in xpid + seed = args.trained_seed + + target_dir = f"{os.getcwd()}/populations/fcp/{args.env_names}/{seed}/" + + print("Target dir: ", target_dir) + + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + os.system(f"cp {high_path} {target_dir}high.pkl") + os.system(f"cp {mid_path} {target_dir}mid.pkl") + os.system(f"cp {low_path} {target_dir}low.pkl") + # also copy meta + os.system(f"cp {meta_path} {target_dir}meta.json") + + # make a txt file there and copy the xpid + with open(f"{target_dir}xpid.txt", "w") as f: + f.write(xpid) + + pbar.close() diff --git a/src/minimax/models/__init__.py b/src/minimax/models/__init__.py new file mode 100644 index 0000000..8579102 --- /dev/null +++ b/src/minimax/models/__init__.py @@ -0,0 +1,33 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .registration import register, make + + +from .maze import ( + GridWorldACStudentModel, + GridWorldACTeacherModel, +) + +from .overcooked import ( + ACStudentModel, + ACStudentCriticModel, + ACStudentActorModel, + ACTeacherModel, +) + +__all__ = [ + register, + make, + GridWorldACStudentModel, + GridWorldACTeacherModel, + ACStudentModel, + ACStudentCriticModel, + ACStudentActorModel, + ACTeacherModel, +] diff --git a/src/minimax/models/common.py b/src/minimax/models/common.py new file mode 100644 index 0000000..d72d1c5 --- /dev/null +++ b/src/minimax/models/common.py @@ -0,0 +1,383 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from typing import Tuple, Callable + +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +import chex + +import minimax.envs as envs +from .rnn import CustomOptimizedLSTMCell + +from flax.linen.initializers import constant, orthogonal + + +def calc_gain(kind): + if kind == 'linear': + return 1.0 + elif kind == 'conv': + return 1.0 + elif kind == 'sigmoid': + return 1.0 + elif kind == 'tanh': + return np.sqrt(2) + elif kind == 'relu': + return np.sqrt(2) + elif kind == 'leaky_relu': + return np.sqrt(2/(1+0.01)) + elif kind == 'selu': + return 0.75 + elif kind == 'gelu': + return 0.75 + elif kind == 'crelu': + return np.sqrt(2) + + +def crelu(x): + return jnp.concatenate((nn.relu(x), nn.relu(-x)), axis=-1) + + +def get_activation(name): + if name == 'crelu': + return crelu + else: + return getattr(nn, name) + + +def default_bias_init(scale=1.0): + return nn.initializers.zeros + + +def init_orth(scale=1.0): + return nn.initializers.orthogonal(scale) + + +def make_fc_layers( + name=None, + n_layers=1, + hidden_dim=32, + activation=None, + kernel_init=None, + bias_init=nn.initializers.zeros_init(), + use_layernorm=False): + if kernel_init is None: + kernel_init = init_orth( + scale=calc_gain('linear') + ) + + layers = [] + for i in range(n_layers): + layer_name = None + if name: + layer_name = f'{name}_{i+1}' + + layers.append( + nn.Dense( + hidden_dim, + kernel_init=kernel_init, + bias_init=bias_init, + name=layer_name, + ) + ) + + if activation is not None: + layers.append(activation) + + if use_layernorm: + layers.append(nn.LayerNorm()) + + return nn.Sequential(layers) + + +def make_rnn( + arch='lstm', + kernel_init=init_orth(), + recurrent_kernel_init=init_orth(), + name=None): + if arch == 'lstm': + rnn = CustomOptimizedLSTMCell( + kernel_init=init_orth(), + recurrent_kernel_init=init_orth(), + name=name + ) + elif arch == 'gru': + rnn = nn.GRUCell( + kernel_init=init_orth(), + recurrent_kernel_init=init_orth(), + name=name + ) + else: + rnn = None + + return rnn + + +class RecurrentModuleBase(nn.Module): + def initialize_carry( + self, + rng: chex.PRNGKey, + batch_dims: Tuple[int] = ()) -> Tuple[chex.ArrayTree, chex.ArrayTree]: + """Initialize hidden state of LSTM.""" + if self.recurrent_arch == 'lstm': + return nn.OptimizedLSTMCell.initialize_carry( + rng, batch_dims, self.recurrent_hidden_dim + ) + elif self.recurrent_arch == 'gru': + return nn.GRUCell.initialize_carry( + rng, batch_dims, self.recurrent_hidden_dim + ) + else: + raise ValueError('Model is not recurrent.') + + @property + def is_recurrent(self): + return self.recurrent_arch is not None + + +class ScannedRNN(nn.Module): + """ + Scanned RNN. + Inputs: + carry: time-major input hidden states, LxBxH and optional + resets: Reset flags of shape LxB, where 1 indicates reset (equivalent to done==True). + """ + recurrent_arch: str = 'lstm' + recurrent_hidden_dim: int = 256 + kernel_init: Callable = init_orth() + recurrent_kernel_init: Callable = init_orth() + + @partial( + nn.scan, + variable_broadcast="params", + in_axes=0, + out_axes=0, + split_rngs={"params": False}, + ) + @nn.compact + def __call__(self, carry, step): + x, reset = step + rnn_state = carry + + # zero_carry = ScannedRNN.initialize_carry(jax.random.PRNGKey( + # 0), (*x.shape[:-1],), self.recurrent_hidden_dim, self.recurrent_arch) + # rnn_state = jax.tree_map( + # lambda x, y: jax.vmap(jax.lax.select)(reset, x, y), + # zero_carry, + # rnn_state + # ) + + rnn_state = jax.tree_map( + lambda x, y: jax.vmap(jax.lax.select)(reset, x, y), + ScannedRNN.initialize_carry( + jax.random.PRNGKey(0), (x.shape[0],), self.recurrent_hidden_dim, self.recurrent_arch), + rnn_state + ) + + rnn_kwargs = dict( + features=self.recurrent_hidden_dim, + kernel_init=self.kernel_init, + recurrent_kernel_init=self.recurrent_kernel_init, + ) + if self.recurrent_arch == 'lstm': + rnn_cell = nn.OptimizedLSTMCell( + **rnn_kwargs) # defaults to orth init + elif self.recurrent_arch == 'gru': + rnn_cell = nn.GRUCell(**rnn_kwargs) + else: + raise ValueError( + f'Unsupported recurrent_arch={self.recurrent_arch}') + + new_rnn_state, y = rnn_cell(rnn_state, x) + return new_rnn_state, y + + @staticmethod + def initialize_carry(rng, batch_dims, recurrent_hidden_dim, recurrent_arch): + init_args = (rng, (*batch_dims, recurrent_hidden_dim)) + if recurrent_arch == 'lstm': + # defaults to orth init + return nn.OptimizedLSTMCell(recurrent_hidden_dim, parent=None).initialize_carry(*init_args) + elif recurrent_arch == 'gru': + return nn.GRUCell(recurrent_hidden_dim, parent=None).initialize_carry(*init_args) + else: + raise ValueError(f'Unsupported recurrent_arch={recurrent_arch}') + + +class StateEncoderFF(nn.Module): + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + elif self.activation == "tanh": + activation = nn.tanh + else: + raise ValueError('Activation not recognized.') + + x = x.reshape((*x.shape[:-3], -1)) + x = nn.Dense( + 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + x = activation(x) + x = nn.LayerNorm()(x) + + x = nn.Dense( + 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + x = activation(x) + x = nn.LayerNorm()(x) + + x = nn.Dense( + 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + x = activation(x) + x = nn.LayerNorm()(x) + return x + + +class StateCNNBase(nn.Module): + activation: str = "tanh" + out_features: int = 32 + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + elif self.activation == "tanh": + activation = nn.tanh + else: + raise ValueError('Activation not recognized.') + + x = nn.Conv(features=32, kernel_size=(3, 3))(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(1, 1), padding="SAME") + x = activation(x) + x = nn.Conv(features=64, kernel_size=(3, 3))(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(1, 1), padding="SAME") + x = activation(x) + x = nn.Conv(features=self.out_features, kernel_size=(3, 3))(x) + x = nn.max_pool(x, window_shape=(2, 2), strides=(1, 1), padding="SAME") + x = activation(x) + return x + + +class StateEncoderCNN(nn.Module): + activation: str = "tanh" + + @nn.compact + def __call__(self, x): + if self.activation == "relu": + activation = nn.relu + elif self.activation == "tanh": + activation = nn.tanh + else: + raise ValueError('Activation not recognized.') + + x = StateCNNBase(activation=activation)(x) + x = x.reshape((*x.shape[:-3], -1)) # Flatten + + x = nn.Dense( + 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + )(x) + x = activation(x) + x = nn.LayerNorm()(x) + + # x = nn.Dense( + # 64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0) + # )(x) + # x = activation(x) + # x = nn.LayerNorm()(x) + + return x + + +class ValueHead(nn.Module): + n_hidden_layers: int = 1 + hidden_dim: int = 256 + activation: Callable = nn.tanh + kernel_init: Callable = init_orth(calc_gain('tanh')) + last_layer_kernel_init: Callable = init_orth(calc_gain('linear')) + use_layernorm: bool = False + + @nn.compact + def __call__(self, x): + + if self.n_hidden_layers > 1: + nn.Sequential([ + make_fc_layers( + n_layers=self.n_hidden_layers, + hidden_dim=self.hidden_dim, + activation=self.activation, + kernel_init=self.kernel_init, + use_layernorm=self.use_layernorm + ), + nn.Dense( + 1, + kernel_init=self.last_layer_kernel_init, + name='fc_value_final' + ) + ])(x) + return nn.Sequential([ + nn.Dense( + 1, + kernel_init=self.last_layer_kernel_init, + name='fc_value_final' + ) + ])(x) + + +class EnsembleValueHead(nn.Module): + n: int = 2 + + n_hidden_layers: int = 1 + hidden_dim: int = 256 + activation: Callable = nn.tanh + kernel_init: Callable = init_orth(calc_gain('tanh')) + last_layer_kernel_init: Callable = init_orth(calc_gain('linear')) + + @nn.compact + def __call__(self, x): + """ + Assumes x is batch + """ + VmapValue = nn.vmap( + ValueHead, + variable_axes={"params": 0}, + split_rngs={"params": True}, + in_axes=None, + out_axes=1, + axis_size=self.n, + ) + vs = VmapValue( + n_hidden_layers=self.n_hidden_layers, + hidden_dim=self.hidden_dim, + activation=self.activation, + kernel_init=self.kernel_init, + last_layer_kernel_init=self.last_layer_kernel_init + )(x) + + return vs + + +def clean_init_kwargs_prefix(prefix): + def class_decorator(cls): + old_init = cls.__init__ + + def new_init(self, *args, **kwargs): + kwargs = { + k.removeprefix(prefix): v for k, v in kwargs.items() + } + old_init(self, *args, **kwargs) + + cls.__init__ = new_init + return cls + + return class_decorator diff --git a/src/minimax/models/fast_attention.py b/src/minimax/models/fast_attention.py new file mode 100644 index 0000000..3da27e8 --- /dev/null +++ b/src/minimax/models/fast_attention.py @@ -0,0 +1,711 @@ +# coding=utf-8 +# Copyright 2023 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core Fast Attention Module for Flax. + +Implementation of the approximate fast softmax and generalized +attention mechanism leveraging structured random feature maps [RFM] techniques +and low rank decomposition of the attention matrix. +""" +# pylint: disable=invalid-name, missing-function-docstring, line-too-long + +import abc +from collections.abc import Iterable # pylint: disable=g-importing-member +import functools +from absl import logging +import jax +from jax import lax +from jax import random +import jax.numpy as jnp + +import numpy as onp + + +def nonnegative_softmax_kernel_feature_creator(data, + projection_matrix, + attention_dims_t, + batch_dims_t, + precision, + is_query, + normalize_data=True, + eps=0.0001): + """Constructs nonnegative kernel features for fast softmax attention. + + + Args: + data: input for which features are computes + projection_matrix: random matrix used to compute features + attention_dims_t: tuple of attention dimensions + batch_dims_t: tuple of batch dimensions + precision: precision parameter + is_query: predicate indicating whether input data corresponds to queries or + keys + normalize_data: predicate indicating whether data should be normalized, + eps: numerical stabilizer. + + Returns: + Random features for fast softmax attention. + """ + + if normalize_data: + # We have e^{qk^T/sqrt{d}} = e^{q_norm k_norm^T}, where + # w_norm = w * data_normalizer for w in {q,k}. + data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) + else: + data_normalizer = 1.0 + ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0]) + data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape + data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix + + data_dash = lax.dot_general( + data_normalizer * data, + data_thick_random_matrix, + (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), + (batch_dims_t, batch_dims_t)), + precision=precision) + + diag_data = jnp.square(data) + diag_data = jnp.sum(diag_data, axis=data.ndim - 1) + diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer + diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1) + + last_dims_t = (len(data_dash.shape) - 1,) + if is_query: + data_dash = ratio * ( + jnp.exp(data_dash - diag_data - + jnp.max(data_dash, axis=last_dims_t, keepdims=True)) + eps) + else: + data_dash = ratio * ( + jnp.exp(data_dash - diag_data - jnp.max( + data_dash, axis=last_dims_t + attention_dims_t, keepdims=True)) + + eps) + + return data_dash + + +def sincos_softmax_kernel_feature_creator(data, + projection_matrix, + attention_dims_t, + batch_dims_t, + precision, + normalize_data=True): + """Constructs kernel sin-cos features for fast softmax attention. + + + Args: + data: input for which features are computes + projection_matrix: random matrix used to compute features + attention_dims_t: tuple of attention dimensions + batch_dims_t: tuple of batch dimensions + precision: precision parameter + normalize_data: predicate indicating whether data should be normalized. + + Returns: + Random features for fast softmax attention. + """ + if normalize_data: + # We have: exp(qk^T/sqrt{d}) = exp(|q|^2/2sqrt{d}) * exp(|k|^2/2sqrt{d}) * + # exp(-(|q*c-k*c|^2)/2), where c = 1.0 / sqrt{sqrt{d}}. + data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) + else: + data_normalizer = 1.0 + ratio = 1.0 / jnp.sqrt(projection_matrix.shape[0]) + data_mod_shape = data.shape[0:len(batch_dims_t)] + projection_matrix.shape + data_thick_random_matrix = jnp.zeros(data_mod_shape) + projection_matrix + + data_dash = lax.dot_general( + data_normalizer * data, + data_thick_random_matrix, + (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), + (batch_dims_t, batch_dims_t)), + precision=precision) + data_dash_cos = ratio * jnp.cos(data_dash) + data_dash_sin = ratio * jnp.sin(data_dash) + data_dash = jnp.concatenate((data_dash_cos, data_dash_sin), axis=-1) + + # Constructing D_data and data^{'} + diag_data = jnp.square(data) + diag_data = jnp.sum(diag_data, axis=data.ndim - 1) + diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer + diag_data = jnp.expand_dims(diag_data, axis=data.ndim - 1) + # Additional renormalization for numerical stability + data_renormalizer = jnp.max(diag_data, attention_dims_t, keepdims=True) + diag_data -= data_renormalizer + diag_data = jnp.exp(diag_data) + data_prime = data_dash * diag_data + return data_prime + + +def generalized_kernel_feature_creator(data, projection_matrix, batch_dims_t, + precision, kernel_fn, kernel_epsilon, + normalize_data): + """Constructs kernel features for fast generalized attention. + + + Args: + data: input for which features are computes + projection_matrix: matrix used to compute features + batch_dims_t: tuple of batch dimensions + precision: precision parameter + kernel_fn: kernel function used + kernel_epsilon: additive positive term added to every feature for numerical + stability + normalize_data: predicate indicating whether data should be normalized. + + Returns: + Random features for fast generalized attention. + """ + if normalize_data: + data_normalizer = 1.0 / (jnp.sqrt(jnp.sqrt(data.shape[-1]))) + else: + data_normalizer = 1.0 + if projection_matrix is None: + return kernel_fn(data_normalizer * data) + kernel_epsilon + else: + data_mod_shape = data.shape[0:len( + batch_dims_t)] + projection_matrix.shape + data_thick_random_matrix = jnp.zeros( + data_mod_shape) + projection_matrix + data_dash = lax.dot_general( + data_normalizer * data, + data_thick_random_matrix, + (((data.ndim - 1,), (data_thick_random_matrix.ndim - 1,)), + (batch_dims_t, batch_dims_t)), + precision=precision) + data_prime = kernel_fn(data_dash) + kernel_epsilon + return data_prime + + +def make_fast_softmax_attention(qkv_dim, + renormalize_attention=True, + numerical_stabilizer=0.000001, + nb_features=256, + ortho_features=True, + ortho_scaling=0.0, + redraw_features=True, + unidirectional=False, + nonnegative_features=True, + lax_scan_unroll=1): + """Construct a fast softmax attention method.""" + logging.info( + 'Fast softmax attention: %s features and orthogonal=%s, renormalize=%s', + nb_features, ortho_features, renormalize_attention) + if ortho_features: + matrix_creator = functools.partial( + GaussianOrthogonalRandomMatrix, + nb_features, + qkv_dim, + scaling=ortho_scaling) + else: + matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, + nb_features, qkv_dim) + if nonnegative_features: + + def kernel_feature_creator(data, + projection_matrix, + attention_dims_t, + batch_dims_t, + precision, + is_query, + normalize_data=True): + return nonnegative_softmax_kernel_feature_creator( + data, projection_matrix, attention_dims_t, batch_dims_t, precision, + is_query, normalize_data, numerical_stabilizer) + else: + + def kernel_feature_creator(data, + projection_matrix, + attention_dims_t, + batch_dims_t, + precision, + is_query, + normalize_data=True): + del is_query + return sincos_softmax_kernel_feature_creator(data, projection_matrix, + attention_dims_t, + batch_dims_t, precision, + normalize_data) + + attention_fn = FastAttentionviaLowRankDecomposition( + matrix_creator, + kernel_feature_creator, + renormalize_attention=renormalize_attention, + numerical_stabilizer=numerical_stabilizer, + redraw_features=redraw_features, + unidirectional=unidirectional, + lax_scan_unroll=lax_scan_unroll).dot_product_attention + return attention_fn + + +def make_fast_generalized_attention(qkv_dim, + renormalize_attention=True, + numerical_stabilizer=0.0, + nb_features=256, + features_type='deterministic', + kernel_fn=jax.nn.relu, + kernel_epsilon=0.001, + redraw_features=False, + unidirectional=False, + lax_scan_unroll=1): + """Construct a fast generalized attention menthod.""" + logging.info('Fast generalized attention.: %s features and renormalize=%s', + nb_features, renormalize_attention) + if features_type == 'ortho': + matrix_creator = functools.partial( + GaussianOrthogonalRandomMatrix, nb_features, qkv_dim, scaling=False) + elif features_type == 'iid': + matrix_creator = functools.partial(GaussianUnstructuredRandomMatrix, + nb_features, qkv_dim) + elif features_type == 'deterministic': + matrix_creator = None + else: + raise ValueError('Unknown feature value type') + + def kernel_feature_creator(data, + projection_matrix, + attention_dims_t, + batch_dims_t, + precision, + is_query, + normalize_data=False): + del attention_dims_t + del is_query + return generalized_kernel_feature_creator(data, projection_matrix, + batch_dims_t, precision, + kernel_fn, kernel_epsilon, + normalize_data) + + attention_fn = FastAttentionviaLowRankDecomposition( + matrix_creator, + kernel_feature_creator, + renormalize_attention=renormalize_attention, + numerical_stabilizer=numerical_stabilizer, + redraw_features=redraw_features, + unidirectional=unidirectional, + lax_scan_unroll=lax_scan_unroll).dot_product_attention + return attention_fn + + +class RandomMatrix(object): + r"""Abstract class providing a method for constructing 2D random arrays. + + Class is responsible for constructing 2D random arrays. + """ + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def get_2d_array(self): + raise NotImplementedError('Abstract method') + + +class GaussianUnstructuredRandomMatrix(RandomMatrix): + + def __init__(self, nb_rows, nb_columns, key): + self.nb_rows = nb_rows + self.nb_columns = nb_columns + self.key = key + + def get_2d_array(self): + return random.normal(self.key, (self.nb_rows, self.nb_columns)) + + +class GaussianOrthogonalRandomMatrix(RandomMatrix): + r"""Class providing a method to create Gaussian orthogonal matrix. + + Class is responsible for constructing 2D Gaussian orthogonal arrays. + """ + + def __init__(self, nb_rows, nb_columns, key, scaling=0): + self.nb_rows = nb_rows + self.nb_columns = nb_columns + self.key = key + self.scaling = scaling + + def get_2d_array(self): + nb_full_blocks = int(self.nb_rows / self.nb_columns) + block_list = [] + rng = self.key + for _ in range(nb_full_blocks): + rng, rng_input = jax.random.split(rng) + unstructured_block = random.normal(rng_input, + (self.nb_columns, self.nb_columns)) + q, _ = jnp.linalg.qr(unstructured_block) + q = jnp.transpose(q) + block_list.append(q) + remaining_rows = self.nb_rows - nb_full_blocks * self.nb_columns + if remaining_rows > 0: + rng, rng_input = jax.random.split(rng) + unstructured_block = random.normal(rng_input, + (self.nb_columns, self.nb_columns)) + q, _ = jnp.linalg.qr(unstructured_block) + q = jnp.transpose(q) + block_list.append(q[0:remaining_rows]) + final_matrix = jnp.vstack(block_list) + + if self.scaling == 0: + multiplier = jnp.linalg.norm( + random.normal(self.key, (self.nb_rows, self.nb_columns)), axis=1) + elif self.scaling == 1: + multiplier = jnp.sqrt(float(self.nb_columns)) * \ + jnp.ones((self.nb_rows)) + else: + raise ValueError( + 'Scaling must be one of {0, 1}. Was %s' % self._scaling) + + return jnp.matmul(jnp.diag(multiplier), final_matrix) + + +class FastAttention(object): + r"""Abstract class providing a method for fast attention. + + Class is responsible for providing a method for fast + approximate attention. + """ + + __metaclass__ = abc.ABCMeta + + @abc.abstractmethod + def dot_product_attention(self, + query, + key, + value, + dtype=jnp.float32, + bias=None, + mask=None, + axis=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0., + deterministic=False, + precision=None): + """Computes dot-product attention given query, key, and value. + + This is the core function for applying fast approximate dot-product + attention. It calculates the attention weights given query and key and + combines the values using the attention weights. This function supports + multi-dimensional inputs. + + + Args: + query: queries for calculating attention with shape of [batch_size, dim1, + dim2, ..., dimN, num_heads, mem_channels]. + key: keys for calculating attention with shape of [batch_size, dim1, dim2, + ..., dimN, num_heads, mem_channels]. + value: values to be used in attention with shape of [batch_size, dim1, + dim2,..., dimN, num_heads, value_channels]. + dtype: the dtype of the computation (default: float32) + bias: bias for the attention weights. This can be used for incorporating + autoregressive mask, padding mask, proximity bias. + mask: mask for the attention weights. This can be used for incorporating + autoregressive masks. + axis: axises over which the attention is applied. + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rng: JAX PRNGKey: to be used for dropout. + dropout_rate: dropout rate. + deterministic: bool, deterministic or not (to apply dropout). + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + + Returns: + Output of shape [bs, dim1, dim2, ..., dimN,, num_heads, value_channels]. + """ + raise NotImplementedError('Abstract method') + + +def _numerator(z_slice_shape, precision, unroll=1): + + def fwd(qs, ks, vs): + + def body(p, qkv): + (q, k, v) = qkv + p += jnp.einsum('...m,...d->...md', k, v, precision=precision) + X_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision) + return p, X_slice + + init_value = jnp.zeros(z_slice_shape) + p, W = lax.scan(body, init_value, (qs, ks, vs), unroll=unroll) + return W, (p, qs, ks, vs) + + def bwd(pqkv, W_ct): + + def body(carry, qkv_xct): + p, p_ct = carry + q, k, v, x_ct = qkv_xct + q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision) + p_ct += jnp.einsum('...d,...m->...md', x_ct, + q, precision=precision) + k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision) + v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision) + p -= jnp.einsum('...m,...d->...md', k, v, precision=precision) + return (p, p_ct), (q_ct, k_ct, v_ct) + + p, qs, ks, vs = pqkv + _, (qs_ct, ks_ct, vs_ct) = lax.scan( + body, (p, jnp.zeros_like(p)), (qs, ks, vs, W_ct), + reverse=True, + unroll=unroll) + return qs_ct, ks_ct, vs_ct + + @jax.custom_vjp + def _numerator_impl(qs, ks, vs): + W, _ = fwd(qs, ks, vs) + return W + + _numerator_impl.defvjp(fwd, bwd) + + return _numerator_impl + + +def _denominator(t_slice_shape, precision, unroll=1): + + def fwd(qs, ks): + + def body(p, qk): + q, k = qk + p += k + x = jnp.einsum('...m,...m->...', q, p, precision=precision) + return p, x + + p = jnp.zeros(t_slice_shape) + p, R = lax.scan(body, p, (qs, ks), unroll=unroll) + return R, (qs, ks, p) + + def bwd(qkp, R_ct): + + def body(carry, qkx): + p, p_ct = carry + q, k, x_ct = qkx + q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision) + p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision) + k_ct = p_ct + p -= k + return (p, p_ct), (q_ct, k_ct) + + qs, ks, p = qkp + _, (qs_ct, ks_ct) = lax.scan( + body, (p, jnp.zeros_like(p)), (qs, ks, R_ct), + reverse=True, + unroll=unroll) + return (qs_ct, ks_ct) + + @jax.custom_vjp + def _denominator_impl(qs, ks): + R, _ = fwd(qs, ks) + return R + + _denominator_impl.defvjp(fwd, bwd) + + return _denominator_impl + + +class FastAttentionviaLowRankDecomposition(FastAttention): + r"""Class providing a method for fast attention via low rank decomposition. + + Class is responsible for providing a method for fast + dot-product attention with the use of low rank decomposition (e.g. with + random feature maps). + """ + + def __init__(self, + matrix_creator, + kernel_feature_creator, + renormalize_attention, + numerical_stabilizer, + redraw_features, + unidirectional, + lax_scan_unroll=1): # For optimal GPU performance, set to 16. + rng = random.PRNGKey(0) + self.matrix_creator = matrix_creator + self.projection_matrix = self.draw_weights(rng) + self.kernel_feature_creator = kernel_feature_creator + self.renormalize_attention = renormalize_attention + self.numerical_stabilizer = numerical_stabilizer + self.redraw_features = redraw_features + self.unidirectional = unidirectional + self.lax_scan_unroll = lax_scan_unroll + + def draw_weights(self, key): + if self.matrix_creator is None: + return None + matrixrng, _ = random.split(key) + projection_matrix = self.matrix_creator(key=matrixrng).get_2d_array() + return projection_matrix + + def dot_product_attention(self, + query, + key, + value, + dtype=jnp.float32, + bias=None, + mask=None, + axis=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0., + deterministic=False, + precision=None): + + assert key.shape[:-1] == value.shape[:-1] + assert (query.shape[0:1] == key.shape[0:1] and + query.shape[-1] == key.shape[-1]) + if axis is None: + axis = tuple(range(1, key.ndim - 2)) + if not isinstance(axis, Iterable): + axis = (axis,) + assert key.ndim == query.ndim + assert key.ndim == value.ndim + for ax in axis: + if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2): + raise ValueError('Attention axis must be between the batch ' + 'axis and the last-two axes.') + n = key.ndim + + # Constructing projection tensor. + if self.redraw_features: + query_seed = lax.convert_element_type( + jnp.ceil(jnp.sum(query) * 10000000.0), jnp.int32) + rng = random.PRNGKey(query_seed) + self.projection_matrix = self.draw_weights(rng) + + # batch_dims is , num_heads> + batch_dims = tuple(onp.delete(range(n), axis + (n - 1,))) + # q & k -> (bs, , num_heads, , channels) + qk_perm = batch_dims + axis + (n - 1,) + k_extra_perm = axis + batch_dims + (n - 1,) + key_extra = key.transpose(k_extra_perm) + key = key.transpose(qk_perm) + query = query.transpose(qk_perm) + # v -> (bs, , num_heads, , channels) + v_perm = batch_dims + axis + (n - 1,) + value = value.transpose(v_perm) + batch_dims_t = tuple(range(len(batch_dims))) + attention_dims_t = tuple( + range(len(batch_dims), + len(batch_dims) + len(axis))) + + # Constructing tensors Q^{'} and K^{'}. + query_prime = self.kernel_feature_creator(query, self.projection_matrix, + attention_dims_t, batch_dims_t, + precision, True) + key_prime = self.kernel_feature_creator(key, self.projection_matrix, + attention_dims_t, batch_dims_t, + precision, False) + + if self.unidirectional: + index = attention_dims_t[0] + z_slice_shape = key_prime.shape[0:len(batch_dims_t)] + ( + key_prime.shape[-1],) + (value.shape[-1],) + + numerator_fn = _numerator( + z_slice_shape, precision, self.lax_scan_unroll) + W = numerator_fn( + jnp.moveaxis(query_prime, index, 0), + jnp.moveaxis(key_prime, index, 0), jnp.moveaxis(value, index, 0)) + + # Constructing W = (Q^{'}(K^{'})^{T})_{masked}V + W = jnp.moveaxis(W, 0, index) + + if not self.renormalize_attention: + # Unidirectional, not-normalized attention. + perm_inv = _invert_perm(qk_perm) + result = W.transpose(perm_inv) + return result + else: + # Unidirectional, normalized attention. + thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones( + key_extra.shape[0:len(axis)]) + + index = attention_dims_t[0] + t_slice_shape = key_prime.shape[0:len(batch_dims_t)] + ( + key_prime.shape[-1],) + denominator_fn = _denominator(t_slice_shape, precision, + self.lax_scan_unroll) + R = denominator_fn( + jnp.moveaxis(query_prime, index, 0), + jnp.moveaxis(key_prime, index, 0)) + + R = jnp.moveaxis(R, 0, index) + else: + contract_query = tuple( + range(len(batch_dims) + len(axis), + len(batch_dims) + len(axis) + 1)) + contract_z = tuple(range(len(batch_dims), len(batch_dims) + 1)) + # Constructing Z = (K^{'})^{T}V + # Z (bs, , num_heads, channels_m, channels_v) + Z = lax.dot_general( + key_prime, + value, + ((attention_dims_t, attention_dims_t), + (batch_dims_t, batch_dims_t)), + precision=precision) + # Constructing W = Q^{'}Z = Q^{'}(K^{'})^{T}V + # q (bs, , num_heads, , channels_m) + # Z (bs, , num_heads, channels_m, channels_v) + # W (bs, , num_heads, , channels_v) + W = lax.dot_general( + query_prime, + Z, ((contract_query, contract_z), (batch_dims_t, batch_dims_t)), + precision=precision) + if not self.renormalize_attention: + # Bidirectional, not-normalized attention. + perm_inv = _invert_perm(qk_perm) + result = W.transpose(perm_inv) + return result + else: + # Bidirectional, normalized attention. + thick_all_ones = jnp.zeros(key.shape[0:-1]) + jnp.ones( + key_extra.shape[0:len(axis)]) + contract_key = tuple( + range(len(batch_dims), + len(batch_dims) + len(axis))) + contract_thick_all_ones = tuple( + range(thick_all_ones.ndim - len(axis), thick_all_ones.ndim)) + # Construct T = (K^{'})^{T} 1_L + # k (bs, , num_heads, , channels) + T = lax.dot_general( + key_prime, + thick_all_ones, ((contract_key, contract_thick_all_ones), + (batch_dims_t, batch_dims_t)), + precision=precision) + + # Construct partition function: R = Q^{'} T = Q^{'}(K^{'})^{T} 1_L + # q_p (bs, , num_heads, , channs_m) + # T (bs, , num_heads, channels_m) + R = lax.dot_general( + query_prime, + T, (((query_prime.ndim - 1,), (T.ndim - 1,)), + (batch_dims_t, range(0, + len(T.shape) - 1))), + precision=precision) + + R = R + 2 * self.numerical_stabilizer * ( + jnp.abs(R) <= self.numerical_stabilizer) + R = jnp.reciprocal(R) + R = jnp.expand_dims(R, len(R.shape)) + # W (bs, , num_heads, , channels_v) + # R (bs, , num_heads, , extra_channel) + result = W * R + # back to (bs, dim1, dim2, ..., dimN, num_heads, channels) + perm_inv = _invert_perm(qk_perm) + result = result.transpose(perm_inv) + return result + + +def _invert_perm(perm): + perm_inv = [0] * len(perm) + for i, j in enumerate(perm): + perm_inv[j] = i + return tuple(perm_inv) diff --git a/src/minimax/models/maze/__init__.py b/src/minimax/models/maze/__init__.py new file mode 100644 index 0000000..2ddb707 --- /dev/null +++ b/src/minimax/models/maze/__init__.py @@ -0,0 +1,12 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .gridworld_models import ( + GridWorldACStudentModel, + GridWorldACTeacherModel, +) \ No newline at end of file diff --git a/src/minimax/models/maze/gridworld_models.py b/src/minimax/models/maze/gridworld_models.py new file mode 100644 index 0000000..36d8e6d --- /dev/null +++ b/src/minimax/models/maze/gridworld_models.py @@ -0,0 +1,277 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Tuple + +import numpy as np +import jax +import jax.numpy as jnp +import flax.linen as nn +import chex +from tensorflow_probability.substrates import jax as tfp + +from minimax.models import common +from minimax.models import s5 +from minimax.models.registration import register + + +class GridWorldBasicModel(nn.Module): + """Split Actor-Critic Architecture for PPO.""" + output_dim: int = 7 + n_hidden_layers: int = 1 + hidden_dim: int = 32 + n_conv_layers: int = 1 + n_conv_filters: int = 16 + conv_kernel_size: int = 3 + n_scalar_embeddings: int = 4 + max_scalar: int = 4 + scalar_embed_dim: int = 5 + recurrent_arch: str = None + recurrent_hidden_dim: int = 256 + base_activation: str = 'relu' + head_activation: str = 'tanh' + + s5_n_blocks: int = 2 + s5_n_layers: int = 4 + s5_layernorm_pos: str = None + s5_activation: str = "half_glu1" + + value_ensemble_size: int = 1 + + def setup(self): + self.conv = nn.Sequential([ + nn.Conv( + features=self.n_conv_filters, + kernel_size=[self.conv_kernel_size,]*2, + strides=1, + kernel_init=common.init_orth( + scale=common.calc_gain(self.base_activation) + ), + padding='VALID', + name='cnn'), + common.get_activation(self.base_activation) + ]) + + if self.n_scalar_embeddings > 0: + self.fc_scalar = nn.Embed( + num_embeddings=self.n_scalar_embeddings, + features=self.scalar_embed_dim, + embedding_init=common.init_orth( + common.calc_gain('linear') + ), + name=f'fc_scalar' + ) + elif self.scalar_embed_dim > 0: + self.fc_scalar = nn.Dense( + self.scalar_embed_dim, + kernel_init=common.init_orth( + common.calc_gain('linear') + ), + name=f'fc_scalar' + ) + else: + self.fc_scalar = None + + if self.recurrent_arch is not None: + if self.recurrent_arch == 's5': + self.embed_pre_s5 = nn.Sequential([ + nn.Dense( + self.recurrent_hidden_dim, + kernel_init=common.init_orth( + common.calc_gain('linear') + ), + name=f'fc_pre_s5' + ) + ]) + self.rnn = s5.make_s5_encoder_stack( + input_dim=self.recurrent_hidden_dim, + ssm_state_dim=self.recurrent_hidden_dim, + n_blocks=self.s5_n_blocks, + n_layers=self.s5_n_layers, + activation=self.s5_activation, + layernorm_pos=self.s5_layernorm_pos + ) + else: + self.rnn = common.ScannedRNN( + recurrent_arch=self.recurrent_arch, + recurrent_hidden_dim=self.recurrent_hidden_dim, + kernel_init=common.init_orth(), + recurrent_kernel_init=common.init_orth() + ) + else: + self.rnn = None + + self.pi_head = nn.Sequential([ + common.make_fc_layers( + 'fc_pi', + n_layers=self.n_hidden_layers, + hidden_dim=self.hidden_dim, + activation=common.get_activation(self.head_activation), + kernel_init=common.init_orth( + common.calc_gain(self.head_activation) + ) + ), + nn.Dense( + self.output_dim, + kernel_init=nn.initializers.constant(0.01), + name=f'fc_pi_final' + ) + ]) + + value_head_kwargs = dict( + n_hidden_layers=self.n_hidden_layers, + hidden_dim=self.hidden_dim, + activation=nn.tanh, + kernel_init=common.init_orth( + common.calc_gain('tanh') + ), + last_layer_kernel_init=common.init_orth( + common.calc_gain('linear') + ) + ) + + if self.value_ensemble_size > 1: + self.v_head = common.EnsembleValueHead( + n=self.value_ensemble_size, **value_head_kwargs) + else: + self.v_head = common.ValueHead(**value_head_kwargs) + + def __call__(self, x, carry=None): + raise NotImplementedError + + def initialize_carry( + self, + rng: chex.PRNGKey, + batch_dims: Tuple[int] = ()) -> Tuple[chex.ArrayTree, chex.ArrayTree]: + """Initialize hidden state of LSTM.""" + if self.recurrent_arch is not None: + if self.recurrent_arch == 's5': + return s5.S5EncoderStack.initialize_carry( # Since conj_sym=True + rng, batch_dims, self.recurrent_hidden_dim//2, self.s5_n_layers + ) + else: + return common.ScannedRNN.initialize_carry( + rng, batch_dims, self.recurrent_hidden_dim, self.recurrent_arch) + else: + raise ValueError('Model is not recurrent.') + + @property + def is_recurrent(self): + return self.recurrent_arch is not None + + +class GridWorldACStudentModel(GridWorldBasicModel): + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + old_x = x + img = x['image'] + agent_dir = x['agent_dir'] + aux = x.get('aux') + + if self.rnn is not None: + batch_dims = img.shape[:2] + x = self.conv(img).reshape(*batch_dims, -1) + else: + batch_dims = img.shape[:1] + x = self.conv(img).reshape(*batch_dims, -1) + + if self.fc_scalar is not None: + if self.n_scalar_embeddings == 0: + agent_dir /= self.max_scalar + + scalar_emb = self.fc_scalar(agent_dir).reshape(*batch_dims, -1) + x = jnp.concatenate([x, scalar_emb], axis=-1) + + if aux is not None: + x = jnp.concatenate([x, aux], axis=-1) + + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + else: + carry, x = self.rnn(carry, (x, reset)) + + logits = self.pi_head(x) + + v = self.v_head(x) + + return v, logits, carry + + +class GridWorldACTeacherModel(GridWorldBasicModel): + """ + Original teacher model from Dennis et al, 2020. It is identical ins + high-level spec to the student model, but with the additional fwd logic + of concatenating a noise vector. + """ + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + img = x['image'] + time = x['time'] + noise = x.get('noise') + aux = x.get('aux') + + if self.rnn is not None: + batch_dims = img.shape[:2] + x = self.conv(img).reshape(*batch_dims, -1) + else: + batch_dims = img.shape[:1] + x = self.conv(img).reshape(*batch_dims, -1) + + if self.fc_scalar is not None: + if self.n_scalar_embeddings == 0: + time /= self.max_scalar + + scalar_emb = self.fc_scalar(time).reshape(*batch_dims, -1) + x = jnp.concatenate([x, scalar_emb], axis=-1) + + if noise is not None: + noise = noise.reshape(*batch_dims, -1) + x = jnp.concatenate([x, noise], axis=-1) + + if aux is not None: + x = jnp.concatenate([x, aux], axis=-1) + + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + else: + carry, x = self.rnn(carry, (x, reset)) + + logits = self.pi_head(x) + + v = self.v_head(x) + + return v, logits, carry + + +# Register models +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register( + env_group_id='Maze', model_id='default_student_cnn', + entry_point=module_path + ':GridWorldACStudentModel') + +register( + env_group_id='Maze', model_id='default_teacher_cnn', + entry_point=module_path + ':GridWorldACTeacherModel') diff --git a/src/minimax/models/moe.py b/src/minimax/models/moe.py new file mode 100644 index 0000000..06eb461 --- /dev/null +++ b/src/minimax/models/moe.py @@ -0,0 +1,162 @@ +from typing import Any + +import einops +import jax +import flax.linen as nn +import jax.numpy as jnp + +from flax.linen.initializers import constant, orthogonal + +from minimax.models.common import StateCNNBase + + + +class MultiExpertLayer(nn.Module): + + in_features: int + out_features: int + num_experts: int + slots_per_expert: int + + def setup(self) -> None: + + self.weight = self.param( + "weight", + nn.initializers.xavier_uniform(), + (self.num_experts, self.in_features, self.out_features), + ) + + self.bias = self.param( + "bias", + nn.initializers.xavier_uniform(), + (self.num_experts, self.out_features), + ) + + def __call__(self, x) -> Any: + x = einops.einsum(x, self.weight, "b n ... d1, n d1 d2 -> b n ... d2") + + if self.bias is not None: + # NOTE: When used with 'SoftMoE' the inputs to 'MultiExpertLayer' will + # always be 4-dimensional. But it's easy enough to generalize for 3D + # inputs as well, so I decided to include that here. + # if x.ndim == 3: + # bias = einops.rearrange(self.bias, "n d -> () n d") + if x.ndim == 4: + bias = einops.rearrange(self.bias, "n d -> () n () d") + else: + raise ValueError( + f"Expected input to have 3 or 4 dimensions, but got {x.ndim}" + ) + x = x + bias + + return x + + +class SoftMoE(nn.Module): + + in_features: int + out_features: int + num_experts: int + slots_per_expert: int + + def setup(self) -> None: + self.experts = MultiExpertLayer( + in_features=self.in_features, + out_features=self.out_features, + num_experts=self.num_experts, + slots_per_expert=self.slots_per_expert, + + ) + self.phi = self.param( + 'phi', + nn.initializers.xavier_uniform(), + (self.in_features, self.num_experts, self.slots_per_expert), + ) + + def __call__(self, x) -> Any: + logits = einops.einsum(x, self.phi, "b m d, d n p -> b m n p") + dispatch_weights = nn.softmax(logits, axis=1) + # dispatch_weights = logits.softmax(dim=1) # denoted 'D' in the paper + # NOTE: The 'torch.softmax' function does not support multiple values for the + # 'dim' argument (unlike jax), so we are forced to flatten the last two dimensions. + # Then, we rearrange the Tensor into its original shape. + combine_weights = nn.softmax(logits, axis=(-2,-1)) + # combine_weights = einops.rearrange( + # nn.softmax(logits.reshape((*logits.shape[:-2], -1)), axis=-1), + # # logits.flatten(start_dim=2).softmax(dim=-1), + # "b m (n p) -> b m n p", + # n=self.num_experts, + # ) + + # NOTE: To save memory, I don't rename the intermediate tensors Y, Ys, Xs. + # Instead, I just overwrite the 'x' variable. The names from the paper are + # included in a comment for each line below. + x = einops.einsum( + x, dispatch_weights, "b m d, b m n p -> b n p d") # Xs + x = self.experts(x) # Ys + x = einops.einsum(x, combine_weights, "b n p d, b m n p -> b m d") # Y + + return x + + +class MoE(nn.Module): + activation: str = "tanh" + state_encoder_module: nn.Module = StateCNNBase + hiddem_dim: int = 64 + recurrent_arch: str = None + + def setup(self) -> None: + self.state_encoder = self.state_encoder_module( + activation=self.activation) + + self.moe = SoftMoE( + in_features=self.state_encoder.out_features, + out_features=self.hiddem_dim, + num_experts=4, + slots_per_expert=32, + ) + + self.proj = nn.Dense( + self.hiddem_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0) + ) + self.proj_layer_norm = nn.LayerNorm() + + + + def __call__(self, x) -> Any: + if self.activation == "relu": + activation = nn.relu + elif self.activation == "tanh": + activation = nn.tanh + else: + raise ValueError('Activation not recognized.') + + input_shape = x.shape + state_embedding = self.state_encoder(x) + + if len(input_shape) == 5: + a, n = state_embedding.shape[:2] + state_embedding = einops.rearrange(state_embedding, "a n ... -> (a n) ...", a=a, n=n) + + state_embedding = einops.rearrange(state_embedding, "... w h c -> ... (w h) c") + state_embedding = self.moe(state_embedding) + + state_embedding = x.reshape((*state_embedding.shape[:-2], -1)) + + if len(input_shape) == 5: + state_embedding = einops.rearrange(state_embedding, "(a n) ... -> a n ...", a=a, n=n) + + state_embedding = self.proj(state_embedding) + state_embedding = self.proj_layer_norm(state_embedding) + state_embedding = activation(state_embedding) + + return state_embedding + + +if __name__ == '__main__': + rng = jax.random.PRNGKey(30) + obs = jnp.zeros((200,6,9,26)) + moe = MoE(action_dim=6) + params = moe.init(rng, obs) + logits, _ = moe.apply(params, obs) + jax.debug.breakpoint() \ No newline at end of file diff --git a/src/minimax/models/overcooked/__init__.py b/src/minimax/models/overcooked/__init__.py new file mode 100644 index 0000000..5d61337 --- /dev/null +++ b/src/minimax/models/overcooked/__init__.py @@ -0,0 +1,14 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .models import ( + ACStudentModel, + ACStudentActorModel, + ACStudentCriticModel, + ACTeacherModel, +) diff --git a/src/minimax/models/overcooked/models.py b/src/minimax/models/overcooked/models.py new file mode 100644 index 0000000..7bbc8b0 --- /dev/null +++ b/src/minimax/models/overcooked/models.py @@ -0,0 +1,536 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Tuple, Sequence + +import einops +import numpy as np +import jax +import jax.numpy as jnp +import flax.linen as nn +import chex +from tensorflow_probability.substrates import jax as tfp + +from minimax.models import common +from minimax.models import s5 +from minimax.models import transformer +from minimax.models.registration import register +from minimax.models.moe import SoftMoE + +from flax.linen.initializers import constant, orthogonal + + +class BasicModel(nn.Module): + """Split Actor-Critic Architecture for PPO.""" + output_dim: int = 6 + n_hidden_layers: int = 1 + hidden_dim: int = 32 + n_conv_layers: int = 1 + n_conv_filters: int = 16 + conv_encoder: bool = True + conv_kernel_size: int = 3 + n_scalar_embeddings: int = 4 + max_scalar: int = 4 + scalar_embed_dim: int = 5 + recurrent_arch: str = None + recurrent_hidden_dim: int = 256 + base_activation: str = 'relu' + head_activation: str = 'tanh' + + s5_n_blocks: int = 2 + s5_n_layers: int = 4 + s5_layernorm_pos: str = None + s5_activation: str = "half_glu1" + + transf_init_scale: float = 0.1 + transf_num_layers: int = 2 + transf_num_heads: int = 4 + transf_dropout_prob: float = 0.0 + transf_deterministic: bool = True + transf_return_embeddings: bool = False + transf_use_fast_attention: bool = False + transf_gated: bool = True + + is_soft_moe: bool = False + soft_moe_num_experts: int = 4 + soft_moe_num_slots: int = 32 + + value_ensemble_size: int = 1 + + def setup(self): + + if self.conv_encoder: + conv_list = [] + for i, feat in enumerate([32, 64, 32]): + # padding = "SAME" if i < self.n_conv_layers - 2 else "VALID" + conv_list.append( + nn.Conv( + features=feat, + kernel_size=[self.conv_kernel_size,]*2, + strides=1, + kernel_init=common.init_orth( + scale=common.calc_gain(self.base_activation) + ), + bias_init=common.default_bias_init(), + padding=((1, 1), (1, 1)), # padding, # 'SAME', + name=f'cnn_{i}' + ) + ) + conv_list.append( + common.get_activation(self.base_activation) + ) + + self.conv = nn.Sequential(conv_list) + self.after_conv = common.make_fc_layers( + n_layers=self.n_hidden_layers, + hidden_dim=self.hidden_dim, + activation=common.get_activation(self.base_activation), + kernel_init=common.init_orth( + scale=common.calc_gain(self.base_activation) + ), + bias_init=common.default_bias_init(), + use_layernorm=True, + ) + self.linear_encoder = None + else: + self.conv = None + self.linear_encoder = nn.Sequential([ + nn.Dense( + self.hidden_dim, + kernel_init=common.init_orth( + common.calc_gain('linear') + ), + name=f'fc_linear' + ), + common.get_activation(self.base_activation), + nn.LayerNorm(name='ln_linear'), + ]) + + if self.is_soft_moe: + self.moe = SoftMoE( + in_features=self.n_conv_filters, + out_features=self.hidden_dim, + num_experts=self.soft_moe_num_experts, # 4, + slots_per_expert=self.soft_moe_num_slots, # 32, + ) + + if self.n_scalar_embeddings > 0: + self.fc_scalar = nn.Embed( + num_embeddings=self.n_scalar_embeddings, + features=self.scalar_embed_dim, + embedding_init=common.init_orth( + common.calc_gain('linear') + ), + name=f'fc_scalar' + ) + elif self.scalar_embed_dim > 0: + self.fc_scalar = nn.Dense( + self.scalar_embed_dim, + kernel_init=common.init_orth( + common.calc_gain('linear') + ), + name=f'fc_scalar' + ) + else: + self.fc_scalar = None + + if self.recurrent_arch is not None: + if self.recurrent_arch == 's5': + self.embed_pre_s5 = nn.Sequential([ + nn.Dense( + self.recurrent_hidden_dim, + kernel_init=common.init_orth( + common.calc_gain('linear') + ), + name=f'fc_pre_s5' + ) + ]) + self.rnn = s5.make_s5_encoder_stack( + input_dim=self.recurrent_hidden_dim, + ssm_state_dim=self.recurrent_hidden_dim, + n_blocks=self.s5_n_blocks, + n_layers=self.s5_n_layers, + activation=self.s5_activation, + layernorm_pos=self.s5_layernorm_pos + ) + elif self.recurrent_arch == 'transformer': + self.rnn = transformer.ScannedTransformer( + hidden_dim=self.recurrent_hidden_dim, + init_scale=self.transf_init_scale, + transf_num_layers=self.transf_num_layers, + transf_num_heads=self.transf_num_heads, + transf_dim_feedforward=self.recurrent_hidden_dim, + transf_dropout_prob=self.transf_transf_dropout_prob, + deterministic=self.transf_deterministic, + return_embeddings=self.transf_return_embeddings, + use_fast_attention=self.transf_use_fast_attention, + gated=self.transf_gated, + ) + else: + self.rnn = common.ScannedRNN( + recurrent_arch=self.recurrent_arch, + recurrent_hidden_dim=self.recurrent_hidden_dim, + kernel_init=common.init_orth(), + recurrent_kernel_init=common.init_orth() + ) + else: + self.rnn = None + + self.pi_head = nn.Sequential([ + # common.make_fc_layers( + # 'fc_pi', + # n_layers=self.n_hidden_layers, + # hidden_dim=self.hidden_dim, + # activation=common.get_activation(self.head_activation), + # kernel_init=common.init_orth( + # common.calc_gain(self.head_activation) + # ) + # ), + nn.Dense( + self.output_dim, + kernel_init=nn.initializers.constant(0.01), + name=f'fc_pi_final' + ) + ]) + + value_head_kwargs = dict( + n_hidden_layers=0, + hidden_dim=self.hidden_dim, + activation=nn.tanh, + kernel_init=common.init_orth( + common.calc_gain('tanh') + ), + last_layer_kernel_init=common.init_orth( + common.calc_gain('linear') + ) + ) + + if self.value_ensemble_size > 1: + self.v_head = common.EnsembleValueHead( + n=self.value_ensemble_size, **value_head_kwargs) + else: + self.v_head = common.ValueHead(**value_head_kwargs) + + def __call__(self, x, carry=None): + raise NotImplementedError + + def initialize_carry( + self, + rng: chex.PRNGKey, + batch_dims: Tuple[int] = ()) -> Tuple[chex.ArrayTree, chex.ArrayTree]: + """Initialize hidden state of LSTM.""" + if self.recurrent_arch is not None: + if self.recurrent_arch == 's5': + return s5.S5EncoderStack.initialize_carry( # Since conj_sym=True + rng, batch_dims, self.recurrent_hidden_dim//2, self.s5_n_layers + ) + elif self.recurrent_arch == 'transformer': + return transformer.ScannedTransformer.initialize_carry( + self.recurrent_hidden_dim, batch_dims) + else: + return common.ScannedRNN.initialize_carry( + rng, batch_dims, self.recurrent_hidden_dim, self.recurrent_arch) + else: + raise ValueError('Model is not recurrent.') + + @property + def is_recurrent(self): + return self.recurrent_arch is not None + + +class ACStudentActorModel(BasicModel): + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + img = x + + if self.rnn is not None: + batch_dims = img.shape[:-3] + x = self.conv(img) + else: + batch_dims = img.shape[:-3] + x = self.conv(img) + + if self.is_soft_moe: + initial_shape = x.shape + if len(initial_shape) == 5: + a, n, h, w, f = x.shape + x = einops.rearrange(x, "a n ... -> (a n) ...", a=a, n=n) + + x = einops.rearrange(x, "... w h c -> ... (w h) c") + + x = self.moe(x) + + if len(initial_shape) == 5: + x = einops.rearrange(x, "(a n) ... -> a n ...", a=a, n=n) + + x = x.reshape(*batch_dims, -1) + x = self.after_conv(x) + + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + elif self.recurrent_arch == 'transformer': + x = self.rnn(carry, (x, mask, reset)) + else: + carry, x = self.rnn(carry, (x, reset)) + + logits = self.pi_head(x) + return logits, carry + + +class ACStudentActorModelMlp(BasicModel): + + conv_encoder: bool = False + + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + img = x + + if self.rnn is not None: + batch_dims = img.shape[:-1] + x = self.linear_encoder(img) + x = x.reshape(*batch_dims, -1) + else: + batch_dims = img.shape[:-1] + x = self.linear_encoder(img) + x = x.reshape(*batch_dims, -1) + + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + elif self.recurrent_arch == 'transformer': + x = self.rnn(carry, (x, mask, reset)) + else: + carry, x = self.rnn(carry, (x, reset)) + + logits = self.pi_head(x) + return logits, carry + + +class ACStudentCriticModel(BasicModel): + + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + img = x + + if self.rnn is not None: + batch_dims = img.shape[:-3] + x = self.conv(img) + else: + batch_dims = img.shape[:-3] + x = self.conv(img) + + if self.is_soft_moe: + initial_shape = x.shape + if len(initial_shape) == 5: + a, n, h, w, f = x.shape + x = einops.rearrange(x, "a n ... -> (a n) ...", a=a, n=n) + + x = einops.rearrange(x, "... w h c -> ... (w h) c") + + x = self.moe(x) + + if len(initial_shape) == 5: + x = einops.rearrange(x, "(a n) ... -> a n ...", a=a, n=n) + + x = x.reshape(*batch_dims, -1) + x = self.after_conv(x) + + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + else: + carry, x = self.rnn(carry, (x, reset)) + + v = self.v_head(x) + + return v, carry + + +class ACStudentCriticModelMlp(BasicModel): + + conv_encoder: bool = False + + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + img = x + + if self.rnn is not None: + batch_dims = img.shape[:-1] + x = self.linear_encoder(img) + x = x.reshape(*batch_dims, -1) + else: + batch_dims = img.shape[:-1] + x = self.linear_encoder(img) + x = x.reshape(*batch_dims, -1) + + # NOTE: Continue here tomorrow + # Is x reshape of shape zero?? + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + elif self.recurrent_arch == 'transformer': + x = self.rnn(carry, (x, mask, reset)) + else: + carry, x = self.rnn(carry, (x, reset)) + + v = self.v_head(x) + + return v, carry + + +class ACStudentModel(BasicModel): + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + img = x + + if self.rnn is not None: + batch_dims = img.shape[:-3] + x = self.conv(img) + x = x.reshape(*batch_dims, -1) + else: + batch_dims = img.shape[:-3] + x = self.conv(img) + x = x.reshape(*batch_dims, -1) + + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + elif self.recurrent_arch == 'transformer': + x = self.rnn(carry, (x, mask, reset)) + else: + carry, x = self.rnn(carry, (x, reset)) + + v = self.v_head(x) + + logits = self.pi_head(x) + + return v, logits, carry + + +class ACTeacherModel(BasicModel): + """ + Original teacher model from Dennis et al, 2020. It is identical ins + high-level spec to the student model, but with the additional fwd logic + of concatenating a noise vector. + """ + + def __call__(self, x, carry=None, reset=None): + """ + Inputs: + x: B x h x w observations + hxs: B x hx_dim hidden states + masks: B length vector of done masks + """ + img = x['image'] + time = x['time'] + noise = x.get('noise') + aux = x.get('aux') + + if self.rnn is not None: + batch_dims = img.shape[:2] + x = self.conv(img).reshape(*batch_dims, -1) + else: + batch_dims = img.shape[:1] + x = self.conv(img).reshape(*batch_dims, -1) + + if self.fc_scalar is not None: + if self.n_scalar_embeddings == 0: + time /= self.max_scalar + + scalar_emb = self.fc_scalar(time).reshape(*batch_dims, -1) + x = jnp.concatenate([x, scalar_emb], axis=-1) + + if noise is not None: + noise = noise.reshape(*batch_dims, -1) + x = jnp.concatenate([x, noise], axis=-1) + + if aux is not None: + x = jnp.concatenate([x, aux], axis=-1) + + if self.rnn is not None: + if self.recurrent_arch == 's5': + x = self.embed_pre_s5(x) + carry, x = self.rnn(carry, x, reset) + else: + carry, x = self.rnn(carry, (x, reset)) + + logits = self.pi_head(x) + + v = self.v_head(x) + + return v, logits, carry + + +# Register models +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register( + env_group_id='Overcooked', model_id='default_student_actor_moe', + entry_point=module_path + ':ACStudentActorModelSoftMoE') + +register( + env_group_id='Overcooked', model_id='default_student_critic_moe', + entry_point=module_path + ':ACStudentCriticModelMoE') + +register( + env_group_id='Overcooked', model_id='default_student_actor_cnn', + entry_point=module_path + ':ACStudentActorModel') + +register( + env_group_id='Overcooked', model_id='default_student_critic_cnn', + entry_point=module_path + ':ACStudentCriticModel') + +register( + env_group_id='Overcooked', model_id='default_student_actor_mlp', + entry_point=module_path + ':ACStudentActorModelMlp') + +register( + env_group_id='Overcooked', model_id='default_student_critic_mlp', + entry_point=module_path + ':ACStudentCriticModelMlp') + +register( + env_group_id='Overcooked', model_id='default_student_cnn', + entry_point=module_path + ':ACStudentModel') + +register( + env_group_id='Overcooked', model_id='default_teacher_cnn', + entry_point=module_path + ':ACTeacherModel') diff --git a/src/minimax/models/registration.py b/src/minimax/models/registration.py new file mode 100644 index 0000000..afc84e5 --- /dev/null +++ b/src/minimax/models/registration.py @@ -0,0 +1,51 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import importlib +import copy + +# Global registry +registered_models = {} + + +def _load(name): + mod_name, attr_name = name.split(":") + mod = importlib.import_module(mod_name) + fn = getattr(mod, attr_name) + return fn + + +def _get_register_id(env_group_id, model_id): + return f"{env_group_id.lower()}-{model_id}" + + +def register(env_group_id, model_id, entry_point): + register_id = _get_register_id(env_group_id, model_id) + if register_id in registered_models: + raise ValueError(f'A model has already been registered as {register_id}.') + else: + registered_models[register_id] = entry_point + + +def make( + env_name, model_name=None, **model_kwargs): + env_group_id = env_name.split('-')[0].lstrip('UED') + model_id = model_name + + register_id = _get_register_id(env_group_id, model_id) + if register_id not in registered_models: + raise ValueError(f'No model for {register_id} found.') + else: + entry = registered_models[register_id] + + if callable(entry): + model = entry(**model_kwargs) + else: + model = _load(entry)(**model_kwargs) + + return model \ No newline at end of file diff --git a/src/minimax/models/rnn.py b/src/minimax/models/rnn.py new file mode 100644 index 0000000..01e8b93 --- /dev/null +++ b/src/minimax/models/rnn.py @@ -0,0 +1,98 @@ +""" +Copyright 2018 The JAX Authors. + +This file is based on the OptimizedLSTMCell class from +https://github.com/google/jax + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +from functools import partial +from typing import Any, Tuple, Mapping + +import numpy as np +import jax.numpy as jnp +import flax.linen as nn +from flax.linen.dtypes import promote_dtype +from flax.linen.module import compact +from flax.linen.recurrent import DenseParams + + +Array = Any + + +class CustomOptimizedLSTMCell(nn.OptimizedLSTMCell): + @compact + def __call__(self, carry: Tuple[Array, Array], + inputs: Array) -> Tuple[Tuple[Array, Array], Array]: + r"""An optimized long short-term memory (LSTM) cell. + + Args: + carry: the hidden state of the LSTM cell, initialized using + `LSTMCell.initialize_carry`. + inputs: an ndarray with the input for the current time step. All + dimensions except the final are considered batch dimensions. + + Returns: + A tuple with the new carry and the output. + """ + c, h = carry + hidden_features = h.shape[-1] + + def _concat_dense(inputs: Array, + params: Mapping[str, Tuple[Array, Array]], + use_bias: bool = True) -> Array: + # Concatenates the individual kernels and biases, given in params, into a + # single kernel and single bias for efficiency before applying them using + # dot_general. + kernels, biases = zip(*params.values()) + kernel = jnp.concatenate(kernels, axis=-1) + if use_bias: + bias = jnp.concatenate(biases, axis=-1) + else: + bias = None + inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) + y = jnp.dot(inputs, kernel) + if use_bias: + y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) + + # Split the result back into individual (i, f, g, o) outputs. + split_indices = np.cumsum([kernel.shape[-1] for kernel in kernels[:-1]]) + ys = jnp.split(y, split_indices, axis=-1) + return dict(zip(params.keys(), ys)) + + # Create params with the same names/shapes as `LSTMCell` for compatibility. + dense_params_h = {} + dense_params_i = {} + for component in ['i', 'f', 'g', 'o']: + dense_params_i[component] = DenseParams( + features=hidden_features, use_bias=True, + param_dtype=self.param_dtype, + kernel_init=self.kernel_init, bias_init=self.bias_init, + name=f'i{component}')(inputs) + dense_params_h[component] = DenseParams( + features=hidden_features, use_bias=True, + param_dtype=self.param_dtype, + kernel_init=self.recurrent_kernel_init, bias_init=self.bias_init, + name=f'h{component}')(h) + dense_h = _concat_dense(h, dense_params_h, use_bias=True) + dense_i = _concat_dense(inputs, dense_params_i, use_bias=True) + + i = self.gate_fn(dense_h['i'] + dense_i['i']) + f = self.gate_fn(dense_h['f'] + dense_i['f']) + g = self.activation_fn(dense_h['g'] + dense_i['g']) + o = self.gate_fn(dense_h['o'] + dense_i['o']) + + new_c = f * c + i * g + new_h = o * self.activation_fn(new_c) + return (new_c, new_h), new_h diff --git a/src/minimax/models/s5.py b/src/minimax/models/s5.py new file mode 100644 index 0000000..bacab12 --- /dev/null +++ b/src/minimax/models/s5.py @@ -0,0 +1,706 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This file is modified from +https://github.com/luchris429/purejaxrl + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 +""" + +from functools import partial +import jax +import jax.numpy as np +import jax.numpy as jnp +from flax import linen as nn +from jax.nn.initializers import lecun_normal, normal +from jax import random +from jax.numpy.linalg import eigh +from jax.scipy.linalg import block_diag + + +def log_step_initializer(dt_min=0.001, dt_max=0.1): + """ Initialize the learnable timescale Delta by sampling + uniformly between dt_min and dt_max. + Args: + dt_min (float32): minimum value + dt_max (float32): maximum value + Returns: + init function + """ + def init(key, shape): + """ Init function + Args: + key: jax random key + shape tuple: desired shape + Returns: + sampled log_step (float32) + """ + return random.uniform(key, shape) * ( + np.log(dt_max) - np.log(dt_min) + ) + np.log(dt_min) + + return init + + +def init_log_steps(key, input): + """ Initialize an array of learnable timescale parameters + Args: + key: jax random key + input: tuple containing the array shape H and + dt_min and dt_max + Returns: + initialized array of timescales (float32): (H,) + """ + H, dt_min, dt_max = input + log_steps = [] + for i in range(H): + key, skey = random.split(key) + log_step = log_step_initializer( + dt_min=dt_min, dt_max=dt_max)(skey, shape=(1,)) + log_steps.append(log_step) + + return np.array(log_steps) + + +def init_VinvB(init_fun, rng, shape, Vinv): + """ Initialize B_tilde=V^{-1}B. First samples B. Then compute V^{-1}B. + Note we will parameterize this with two different matrices for complex + numbers. + Args: + init_fun: the initialization function to use, e.g. lecun_normal() + rng: jax random key to be used with init function. + shape (tuple): desired shape (P,H) + Vinv: (complex64) the inverse eigenvectors used for initialization + Returns: + B_tilde (complex64) of shape (P,H,2) + """ + B = init_fun(rng, shape) + VinvB = Vinv @ B + VinvB_real = VinvB.real + VinvB_imag = VinvB.imag + return np.concatenate((VinvB_real[..., None], VinvB_imag[..., None]), axis=-1) + + +def trunc_standard_normal(key, shape): + """ Sample C with a truncated normal distribution with standard deviation 1. + Args: + key: jax random key + shape (tuple): desired shape, of length 3, (H,P,_) + Returns: + sampled C matrix (float32) of shape (H,P,2) (for complex parameterization) + """ + H, P, _ = shape + Cs = [] + for i in range(H): + key, skey = random.split(key) + C = lecun_normal()(skey, shape=(1, P, 2)) + Cs.append(C) + return np.array(Cs)[:, 0] + + +def init_CV(init_fun, rng, shape, V): + """ Initialize C_tilde=CV. First sample C. Then compute CV. + Note we will parameterize this with two different matrices for complex + numbers. + Args: + init_fun: the initialization function to use, e.g. lecun_normal() + rng: jax random key to be used with init function. + shape (tuple): desired shape (H,P) + V: (complex64) the eigenvectors used for initialization + Returns: + C_tilde (complex64) of shape (H,P,2) + """ + C_ = init_fun(rng, shape) + C = C_[..., 0] + 1j * C_[..., 1] + CV = C @ V + CV_real = CV.real + CV_imag = CV.imag + return np.concatenate((CV_real[..., None], CV_imag[..., None]), axis=-1) + + +# Discretization functions +def discretize_bilinear(Lambda, B_tilde, Delta): + """ Discretize a diagonalized, continuous-time linear SSM + using bilinear transform method. + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = np.ones(Lambda.shape[0]) + + BL = 1 / (Identity - (Delta / 2.0) * Lambda) + Lambda_bar = BL * (Identity + (Delta / 2.0) * Lambda) + B_bar = (BL * Delta)[..., None] * B_tilde + return Lambda_bar, B_bar + + +def discretize_zoh(Lambda, B_tilde, Delta): + """ Discretize a diagonalized, continuous-time linear SSM + using zero-order hold method. + Args: + Lambda (complex64): diagonal state matrix (P,) + B_tilde (complex64): input matrix (P, H) + Delta (float32): discretization step sizes (P,) + Returns: + discretized Lambda_bar (complex64), B_bar (complex64) (P,), (P,H) + """ + Identity = np.ones(Lambda.shape[0]) + Lambda_bar = np.exp(Lambda * Delta) + B_bar = (1/Lambda * (Lambda_bar-Identity))[..., None] * B_tilde + return Lambda_bar, B_bar + + +# Parallel scan operations +@jax.vmap +def binary_operator(q_i, q_j): + """ Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. + Args: + q_i: tuple containing A_i and Bu_i at position i (P,), (P,) + q_j: tuple containing A_j and Bu_j at position j (P,), (P,) + Returns: + new element ( A_out, Bu_out ) + """ + A_i, b_i = q_i + A_j, b_j = q_j + return A_j * A_i, A_j * b_i + b_j + +# Parallel scan operations + + +@jax.vmap +def binary_operator_reset(q_i, q_j): + """ Binary operator for parallel scan of linear recurrence. Assumes a diagonal matrix A. + Args: + q_i: tuple containing A_i and Bu_i at position i (P,), (P,) + q_j: tuple containing A_j and Bu_j at position j (P,), (P,) + Returns: + new element ( A_out, Bu_out ) + """ + A_i, b_i, c_i = q_i + A_j, b_j, c_j = q_j + return ( + (A_j * A_i)*(1 - c_j) + A_j * c_j, + (A_j * b_i + b_j)*(1 - c_j) + b_j * c_j, + c_i * (1 - c_j) + c_j, + ) + + +def apply_ssm(Lambda_bar, B_bar, C_tilde, hidden, input_sequence, resets, conj_sym, bidirectional): + """ Compute the LxH output of discretized SSM given an LxH input. + Args: + Lambda_bar (complex64): discretized diagonal state matrix (P,) + B_bar (complex64): discretized input matrix (P, H) + C_tilde (complex64): output matrix (H, P) + input_sequence (float32): input sequence of features (L, H) + reset (bool): input sequence of features (L,) + conj_sym (bool): whether conjugate symmetry is enforced + bidirectional (bool): whether bidirectional setup is used, + Note for this case C_tilde will have 2P cols + Returns: + ys (float32): the SSM outputs (S5 layer preactivations) (L, H) + """ + Lambda_elements = Lambda_bar * jnp.ones((input_sequence.shape[0], + Lambda_bar.shape[0])) + Bu_elements = jax.vmap(lambda u: B_bar @ u)(input_sequence) + + Lambda_elements = jnp.concatenate([ + jnp.ones((1, Lambda_bar.shape[0])), + Lambda_elements, + ]) + + Bu_elements = jnp.concatenate([ + hidden, + Bu_elements, + ]) + + resets = jnp.concatenate([ + jnp.zeros(1), + resets, + ]) + + _, xs, _ = jax.lax.associative_scan( + binary_operator_reset, (Lambda_elements, Bu_elements, resets)) + xs = xs[1:] + + if conj_sym: + return xs[np.newaxis, -1], jax.vmap(lambda x: 2*(C_tilde @ x).real)(xs) + else: + return xs[np.newaxis, -1], jax.vmap(lambda x: (C_tilde @ x).real)(xs) + + +class S5SSM(nn.Module): + Lambda_re_init: jax.Array + Lambda_im_init: jax.Array + V: jax.Array + Vinv: jax.Array + + H: int + P: int + C_init: str + discretization: str + dt_min: float + dt_max: float + conj_sym: bool = True + clip_eigs: bool = False + bidirectional: bool = False + step_rescale: float = 1.0 + + """ The S5 SSM + Args: + Lambda_re_init (complex64): Real part of init diag state matrix (P,) + Lambda_im_init (complex64): Imag part of init diag state matrix (P,) + V (complex64): Eigenvectors used for init (P,P) + Vinv (complex64): Inverse eigenvectors used for init (P,P) + H (int32): Number of features of input seq + P (int32): state size + C_init (string): Specifies How C is initialized + Options: [trunc_standard_normal: sample from truncated standard normal + and then multiply by V, i.e. C_tilde=CV. + lecun_normal: sample from Lecun_normal and then multiply by V. + complex_normal: directly sample a complex valued output matrix + from standard normal, does not multiply by V] + conj_sym (bool): Whether conjugate symmetry is enforced + clip_eigs (bool): Whether to enforce left-half plane condition, i.e. + constrain real part of eigenvalues to be negative. + True recommended for autoregressive task/unbounded sequence lengths + Discussed in https://arxiv.org/pdf/2206.11893.pdf. + bidirectional (bool): Whether model is bidirectional, if True, uses two C matrices + discretization: (string) Specifies discretization method + options: [zoh: zero-order hold method, + bilinear: bilinear transform] + dt_min: (float32): minimum value to draw timescale values from when + initializing log_step + dt_max: (float32): maximum value to draw timescale values from when + initializing log_step + step_rescale: (float32): allows for uniformly changing the timescale parameter, e.g. after training + on a different resolution for the speech commands benchmark + """ + + def setup(self): + """Initializes parameters once and performs discretization each time + the SSM is applied to a sequence + """ + if self.conj_sym: + # Need to account for case where we actually sample real B and C, and then multiply + # by the half sized Vinv and possibly V + local_P = 2*self.P + else: + local_P = self.P + + # Initialize diagonal state to state matrix Lambda (eigenvalues) + self.Lambda_re = self.param( + "Lambda_re", lambda rng, shape: self.Lambda_re_init, (None,)) + self.Lambda_im = self.param( + "Lambda_im", lambda rng, shape: self.Lambda_im_init, (None,)) + if self.clip_eigs: + self.Lambda = np.clip(self.Lambda_re, None, - + 1e-4) + 1j * self.Lambda_im + else: + self.Lambda = self.Lambda_re + 1j * self.Lambda_im + + # Initialize input to state (B) matrix + B_init = lecun_normal() + B_shape = (local_P, self.H) + self.B = self.param("B", + lambda rng, shape: init_VinvB(B_init, + rng, + shape, + self.Vinv), + B_shape) + B_tilde = self.B[..., 0] + 1j * self.B[..., 1] + + # Initialize state to output (C) matrix + if self.C_init in ["trunc_standard_normal"]: + C_init = trunc_standard_normal + C_shape = (self.H, local_P, 2) + elif self.C_init in ["lecun_normal"]: + C_init = lecun_normal() + C_shape = (self.H, local_P, 2) + elif self.C_init in ["complex_normal"]: + C_init = normal(stddev=0.5 ** 0.5) + else: + raise NotImplementedError( + "C_init method {} not implemented".format(self.C_init)) + + if self.C_init in ["complex_normal"]: + if self.bidirectional: + C = self.param("C", C_init, (self.H, 2 * self.P, 2)) + self.C_tilde = C[..., 0] + 1j * C[..., 1] + + else: + C = self.param("C", C_init, (self.H, self.P, 2)) + self.C_tilde = C[..., 0] + 1j * C[..., 1] + + else: + if self.bidirectional: + self.C1 = self.param("C1", + lambda rng, shape: init_CV( + C_init, rng, shape, self.V), + C_shape) + self.C2 = self.param("C2", + lambda rng, shape: init_CV( + C_init, rng, shape, self.V), + C_shape) + + C1 = self.C1[..., 0] + 1j * self.C1[..., 1] + C2 = self.C2[..., 0] + 1j * self.C2[..., 1] + self.C_tilde = np.concatenate((C1, C2), axis=-1) + + else: + self.C = self.param("C", + lambda rng, shape: init_CV( + C_init, rng, shape, self.V), + C_shape) + + self.C_tilde = self.C[..., 0] + 1j * self.C[..., 1] + + # Initialize feedthrough (D) matrix + self.D = self.param("D", normal(stddev=1.0), (self.H,)) + + # Initialize learnable discretization timescale value + self.log_step = self.param("log_step", + init_log_steps, + (self.P, self.dt_min, self.dt_max)) + step = self.step_rescale * np.exp(self.log_step[:, 0]) + + # Discretize + if self.discretization in ["zoh"]: + self.Lambda_bar, self.B_bar = discretize_zoh( + self.Lambda, B_tilde, step) + elif self.discretization in ["bilinear"]: + self.Lambda_bar, self.B_bar = discretize_bilinear( + self.Lambda, B_tilde, step) + else: + raise NotImplementedError( + "Discretization method {} not implemented".format(self.discretization)) + + def __call__(self, hidden, input_sequence, resets): + """ + Compute the LxH output of the S5 SSM given an LxH input sequence + using a parallel scan. + Args: + input_sequence (float32): input sequence (L, H) + resets (bool): input sequence (L,) + Returns: + output sequence (float32): (L, H) + """ + hidden, ys = apply_ssm( + self.Lambda_bar, + self.B_bar, + self.C_tilde, + hidden, + input_sequence, + resets, + self.conj_sym, + self.bidirectional) + # Add feedthrough matrix output Du; + Du = jax.vmap(lambda u: self.D * u)(input_sequence) + return hidden, ys + Du + + +class SequenceLayer(nn.Module): + """ Defines a single S5 layer, with S5 SSM, nonlinearity, + dropout, batch/layer norm, etc. + Args: + ssm (nn.Module): the SSM to be used (i.e. S5 ssm) + dropout (float32): dropout rate + d_model (int32): this is the feature size of the layer inputs and outputs + we usually refer to this size as H + activation (string): Type of activation function to use + training (bool): whether in training mode or not + prenorm (bool): apply prenorm if true or postnorm if false + batchnorm (bool): apply batchnorm if true or layernorm if false + bn_momentum (float32): the batchnorm momentum if batchnorm is used + step_rescale (float32): allows for uniformly changing the timescale parameter, + e.g. after training on a different resolution for + the speech commands benchmark + """ + ssm: nn.Module + d_model: int + activation: str = "gelu" + layernorm_pos: str = None # ['pre', 'post', None] + step_rescale: float = 1.0 + + def setup(self): + """Initializes the ssm, batch/layer norm and dropout + """ + self.seq = self.ssm(step_rescale=self.step_rescale) + + if self.activation in ["full_glu"]: + self.out1 = nn.Dense(self.d_model) + self.out2 = nn.Dense(self.d_model) + elif self.activation in ["half_glu1", "half_glu2"]: + self.out2 = nn.Dense(self.d_model) + + self.norm = nn.LayerNorm() + self.drop = lambda x: x + + def __call__(self, hidden, x, d): + """ + Compute the LxH output of S5 layer given an LxH input. + Args: + x (float32): input sequence (L, B, d_model) + d (bool): reset signal (L,B) + Returns: + output sequence (float32): (L, B, d_model) + """ + is_one_step = len(hidden.shape) == 2 + if is_one_step: # Add time axis + hidden = hidden[jnp.newaxis, :] + + skip = x + if self.layernorm_pos == 'pre': + x = self.norm(x) + hidden, x = jax.vmap(self.seq, in_axes=1, out_axes=1)(hidden, x, d) + + if self.activation in ["full_glu"]: + x = self.drop(nn.gelu(x)) + x = self.out1(x) * jax.nn.sigmoid(self.out2(x)) + x = self.drop(x) + elif self.activation in ["half_glu1"]: + x = self.drop(nn.gelu(x)) + x = x * jax.nn.sigmoid(self.out2(x)) + x = self.drop(x) + elif self.activation in ["half_glu2"]: + # Only apply GELU to the gate input + x1 = self.drop(nn.gelu(x)) + x = x * jax.nn.sigmoid(self.out2(x1)) + x = self.drop(x) + elif self.activation in ["gelu"]: + x = self.drop(nn.gelu(x)) + else: + raise NotImplementedError( + "Activation: {} not implemented".format(self.activation)) + + x = skip + x + if self.layernorm_pos == 'post': + x = self.norm(x) + if is_one_step: + hidden = hidden.squeeze(0) + + return hidden, x + + @staticmethod + def initialize_carry(batch_size, hidden_size): + return jnp.zeros((batch_size, hidden_size), dtype=jnp.complex64) + + +def init_S5SSM( + H, + P, + Lambda_re_init, + Lambda_im_init, + V, + Vinv, + C_init, + discretization, + dt_min, + dt_max, + conj_sym, + clip_eigs, + bidirectional): + """Convenience function that will be used to initialize the SSM. + Same arguments as defined in S5SSM above.""" + return partial(S5SSM, + H=H, + P=P, + Lambda_re_init=Lambda_re_init, + Lambda_im_init=Lambda_im_init, + V=V, + Vinv=Vinv, + C_init=C_init, + discretization=discretization, + dt_min=dt_min, + dt_max=dt_max, + conj_sym=conj_sym, + clip_eigs=clip_eigs, + bidirectional=bidirectional) + + +def make_HiPPO(N): + """ Create a HiPPO-LegS matrix. + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + N (int32): state size + Returns: + N x N HiPPO LegS matrix + """ + P = np.sqrt(1 + 2 * np.arange(N)) + A = P[:, np.newaxis] * P[np.newaxis, :] + A = np.tril(A) - np.diag(np.arange(N)) + return -A + + +def make_NPLR_HiPPO(N): + """ + Makes components needed for NPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Args: + N (int32): state size + Returns: + N x N HiPPO LegS matrix, low-rank factor P, HiPPO input matrix B + """ + # Make -HiPPO + hippo = make_HiPPO(N) + + # Add in a rank 1 term. Makes it Normal. + P = np.sqrt(np.arange(N) + 0.5) + + # HiPPO also specifies the B matrix + B = np.sqrt(2 * np.arange(N) + 1.0) + return hippo, P, B + + +def make_DPLR_HiPPO(N): + """ + Makes components needed for DPLR representation of HiPPO-LegS + From https://github.com/srush/annotated-s4/blob/main/s4/s4.py + Note, we will only use the diagonal part + Args: + N: + Returns: + eigenvalues Lambda, low-rank term P, conjugated HiPPO input matrix B, + eigenvectors V, HiPPO B pre-conjugation + """ + A, P, B = make_NPLR_HiPPO(N) + + S = A + P[:, np.newaxis] * P[np.newaxis, :] + + S_diag = np.diagonal(S) + Lambda_real = np.mean(S_diag) * np.ones_like(S_diag) + + # Diagonalize S to V \Lambda V^* + Lambda_imag, V = eigh(S * -1j) + + P = V.conj().T @ P + B_orig = B + B = V.conj().T @ B + return Lambda_real + 1j * Lambda_imag, P, B, V, B_orig + + +class S5EncoderStack(nn.Module): + """ Defines a stack of S5 layers to be used as an encoder. + Args: + ssm (nn.Module): the SSM to be used (i.e. S5 ssm) + d_model (int32): this is the feature size of the layer inputs and outputs + we usually refer to this size as H + n_layers (int32): the number of S5 layers to stack + activation (string): Type of activation function to use + dropout (float32): dropout rate + training (bool): whether in training mode or not + prenorm (bool): apply prenorm if true or postnorm if false + batchnorm (bool): apply batchnorm if true or layernorm if false + bn_momentum (float32): the batchnorm momentum if batchnorm is used + step_rescale (float32): allows for uniformly changing the timescale parameter, + e.g. after training on a different resolution for + the speech commands benchmark + """ + ssm: nn.Module + d_model: int + n_layers: int + layernorm_pos: str = None + activation: str = "gelu" + + def setup(self): + """ + Initializes a linear encoder and the stack of S5 layers. + """ + self.layers = [ + SequenceLayer( + ssm=self.ssm, + d_model=self.d_model, + activation=self.activation, + layernorm_pos=self.layernorm_pos, + ) + for _ in range(self.n_layers) + ] + + def __call__(self, hidden, x, reset): + """ + Compute the BxLxH output of the stacked encoder given an Lxd_input + input sequence. + Args: + x (float32): input sequence (L, d_input) + Returns: + output sequence (float32): (L, d_model) + """ + new_hiddens = [] + for i, layer in enumerate(self.layers): + new_h, x = layer(hidden[i], x, reset) + new_hiddens.append(new_h) + + return new_hiddens, x + + @staticmethod + def initialize_carry(rng, batch_dims, hidden_dim, n_layers): + # Use a dummy key since the default state init fn is just zeros. + return [jnp.zeros((*batch_dims, hidden_dim), dtype=jnp.complex64) for _ in range(n_layers)] + + +BatchS5EncoderStack = nn.vmap( + S5EncoderStack, + in_axes=(1, 1, 1), + out_axes=1, + variable_axes={"params": None}, + split_rngs={"params": False}, axis_name='batch') + + +def make_s5_encoder_stack( + input_dim, + ssm_state_dim, + n_blocks=1, + n_layers=4, + discretization='zoh', + dt_min=0.001, + dt_max=0.1, + conj_sym=True, + clip_eigs=False, + bidirectional=False, + activation="half_glu1", + layernorm_pos=None): + block_size = int(ssm_state_dim / n_blocks) + + Lambda, _, B, V, B_orig = make_DPLR_HiPPO(block_size) + + if conj_sym: + block_size = block_size // 2 + ssm_state_dim = ssm_state_dim // 2 + + Lambda = Lambda[:block_size] + V = V[:, :block_size] + Vc = V.conj().T + + Lambda = (Lambda*np.ones((n_blocks, block_size))).ravel() + V = block_diag(*([V]*n_blocks)) + Vinv = block_diag(*([Vc]*n_blocks)) + + ssm_init_fn = init_S5SSM( + H=input_dim, + P=ssm_state_dim, + Lambda_re_init=Lambda.real, + Lambda_im_init=Lambda.imag, + V=V, + Vinv=Vinv, + C_init="lecun_normal", + discretization=discretization, + dt_min=dt_min, + dt_max=dt_max, + conj_sym=conj_sym, + clip_eigs=clip_eigs, + bidirectional=bidirectional) + + return S5EncoderStack( + ssm=ssm_init_fn, + d_model=input_dim, + n_layers=n_layers, + activation=activation, + layernorm_pos=layernorm_pos + ) diff --git a/src/minimax/models/transformer.py b/src/minimax/models/transformer.py new file mode 100644 index 0000000..ca74cb4 --- /dev/null +++ b/src/minimax/models/transformer.py @@ -0,0 +1,264 @@ +from functools import partial +import flax.linen as nn +import jax.numpy as jnp + +import einops + + +import numpy as np + +from flax.linen.initializers import constant, orthogonal + + +class GRUGating(nn.Module): + + dim: int + scale_residual: bool = False + + def setup(self): + super().__init__() + self.gru = nn.GRUCell(self.dim, self.dim) + self.residual_scale = nn.Parameter( + jnp.ones(self.dim)) if self.scale_residual else None + + def __call__(self, x, residual): + if self.residual_scale is not None: + residual = residual * self.residual_scale + + gated_output = self.gru( + einops.rearrange(x, 'b n d -> (b n) d'), + einops.rearrange(residual, 'b n d -> (b n) d') + ) + + return gated_output.reshape_as(x) + + +class GatedEncoderBlock(nn.Module): + # Input dimension is needed here since it is equal to the output dimension (residual connection) + hidden_dim: int + num_heads: int + dim_feedforward: int + init_scale: float + use_fast_attention: bool + dropout_prob: float = 0. + + def setup(self): + # Attention layer + if self.use_fast_attention: + from fast_attention import make_fast_generalized_attention + raw_attention_fn = make_fast_generalized_attention( + self.hidden_dim // self.num_heads, + renormalize_attention=True, + nb_features=self.hidden_dim, + unidirectional=False + ) + self.self_attn = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + dropout_rate=self.dropout_prob, + attention_fn=raw_attention_fn, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=False, + ) + else: + self.self_attn = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + dropout_rate=self.dropout_prob, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=False, + ) + # Two-layer MLP + self.linear = [ + nn.Dense(self.dim_feedforward, kernel_init=nn.initializers.xavier_uniform( + ), bias_init=constant(0.0)), + nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform( + ), bias_init=constant(0.0)) + ] + # Layers to apply in between the main layers + self.gate1 = GRUGating(dim=self.hidden_dim) + self.norm1 = nn.LayerNorm() + self.gate2 = GRUGating(dim=self.hidden_dim) + self.norm2 = nn.LayerNorm() + self.dropout = nn.Dropout(self.dropout_prob) + + def __call__(self, x, mask=None, deterministic=True): + + # Attention part + # masking is not compatible with fast self attention + x_norm1 = self.norm1(x) + if mask is not None and not self.use_fast_attention: + mask = jnp.repeat(nn.make_attention_mask( + mask, mask), self.num_heads, axis=-3) + attended = self.self_attn( + inputs_q=x_norm1, inputs_kv=x_norm1, mask=mask, deterministic=deterministic) + + # GRU gate + x = self.gate1(attended, x_norm1) + x = self.dropout(x, deterministic=deterministic) + + x_res = x + + # MLP part + x = self.norm2(x) + feedforward = self.linear[0](x) + feedforward = nn.relu(feedforward) + feedforward = self.linear[1](feedforward) + + # GRU Gate + x = self.gate2(x, x_res) + x = self.dropout(x, deterministic=deterministic) + return x + + +class EncoderBlock(nn.Module): + # Input dimension is needed here since it is equal to the output dimension (residual connection) + hidden_dim: int + num_heads: int + dim_feedforward: int + init_scale: float + use_fast_attention: bool + dropout_prob: float = 0. + + def setup(self): + # Attention layer + if self.use_fast_attention: + from fast_attention import make_fast_generalized_attention + raw_attention_fn = make_fast_generalized_attention( + self.hidden_dim // self.num_heads, + renormalize_attention=True, + nb_features=self.hidden_dim, + unidirectional=False + ) + self.self_attn = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + dropout_rate=self.dropout_prob, + attention_fn=raw_attention_fn, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=False, + ) + else: + self.self_attn = nn.MultiHeadDotProductAttention( + num_heads=self.num_heads, + dropout_rate=self.dropout_prob, + kernel_init=nn.initializers.xavier_uniform(), + use_bias=False, + ) + # Two-layer MLP + self.linear = [ + nn.Dense(self.dim_feedforward, kernel_init=nn.initializers.xavier_uniform( + ), bias_init=constant(0.0)), + nn.Dense(self.hidden_dim, kernel_init=nn.initializers.xavier_uniform( + ), bias_init=constant(0.0)) + ] + # Layers to apply in between the main layers + self.norm1 = nn.LayerNorm() + self.norm2 = nn.LayerNorm() + self.dropout = nn.Dropout(self.dropout_prob) + + def __call__(self, x, mask=None, deterministic=True): + + # Attention part + # masking is not compatible with fast self attention + if mask is not None and not self.use_fast_attention: + mask = jnp.repeat(nn.make_attention_mask( + mask, mask), self.num_heads, axis=-3) + attended = self.self_attn( + inputs_q=x, inputs_kv=x, mask=mask, deterministic=deterministic) + + x = self.norm1(attended + x) + x = x + self.dropout(x, deterministic=deterministic) + + # MLP part + feedforward = self.linear[0](x) + feedforward = nn.relu(feedforward) + feedforward = self.linear[1](feedforward) + + x = self.norm2(feedforward+x) + x = x + self.dropout(x, deterministic=deterministic) + + return x + + +class Embedder(nn.Module): + hidden_dim: int + init_scale: float + scale_inputs: bool = True + activation: bool = False + + @nn.compact + def __call__(self, x, train: bool): + if self.scale_inputs: + x = nn.BatchNorm(use_running_average=not train)(x) + x = nn.Dense(self.hidden_dim, kernel_init=orthogonal( + self.init_scale), bias_init=constant(0.0))(x) + if self.activation: + x = nn.relu(x) + x = nn.BatchNorm(use_running_average=not train)(x) + return x + + +class ScannedTransformer(nn.Module): + + hidden_dim: int + init_scale: float + transf_num_layers: int + transf_num_heads: int + transf_dim_feedforward: int + transf_dropout_prob: float = 0 + deterministic: bool = True + return_embeddings: bool = False + use_fast_attention: bool = False + gated: bool = True + + def setup(self): + self.encoders = [ + GatedEncoderBlock( + self.hidden_dim, + self.transf_num_heads, + self.transf_dim_feedforward, + self.init_scale, + self.use_fast_attention, + self.transf_dropout_prob, + ) if self.gated else EncoderBlock( + self.hidden_dim, + self.transf_num_heads, + self.transf_dim_feedforward, + self.init_scale, + self.use_fast_attention, + self.transf_dropout_prob, + ) for _ in range(self.transf_num_layers) + ] + + @partial( + nn.scan, + variable_broadcast="params", + in_axes=0, + out_axes=0, + split_rngs={"params": False}, + ) + def __call__(self, carry, x): + hs = carry + embeddings, mask, done = x + + hs = jnp.where( + done[:, np.newaxis, np.newaxis], + self.initialize_carry(self.hidden_dim, *done.shape, 1), + hs + ) + embeddings = jnp.concatenate(( + hs, + embeddings, + ), axis=-2) + for layer in self.encoders: + embeddings = layer(embeddings, mask=mask, + deterministic=self.deterministic) + hs = embeddings[..., 0:1, :] + + # as y return the entire embeddings if required (i.e. transformer mixer), otherwise only agents' hs embeddings + if self.return_embeddings: + return hs, embeddings + else: + return hs, hs + + @staticmethod + def initialize_carry(hidden_size, *batch_size): + return jnp.zeros((*batch_size, hidden_size)) diff --git a/src/minimax/runners/__init__.py b/src/minimax/runners/__init__.py new file mode 100644 index 0000000..74cd61a --- /dev/null +++ b/src/minimax/runners/__init__.py @@ -0,0 +1,22 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .xp_runner import ExperimentRunner +from .eval_runner import EvalRunner +from .dr_runner import DRRunner +from .plr_runner import PLRRunner +from .paired_runner import PAIREDRunner + + +__all__ = [ + ExperimentRunner, + EvalRunner, + DRRunner, + PLRRunner, + PAIREDRunner +] \ No newline at end of file diff --git a/src/minimax/runners/dr_runner.py b/src/minimax/runners/dr_runner.py new file mode 100644 index 0000000..ecb0ef1 --- /dev/null +++ b/src/minimax/runners/dr_runner.py @@ -0,0 +1,458 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from typing import Tuple, Optional +import inspect + +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +import optax +import flax +import flax.linen as nn +from flax.core.frozen_dict import FrozenDict + +import minimax.envs as envs +from minimax.util import pytree as _tree_util +from minimax.util.rl import ( + AgentPop, + VmapTrainState, + RolloutStorage, + RollingStats +) + + +class DRRunner: + """ + Orchestrates rollouts across one or more students. + The main components at play: + - AgentPop: Manages train state and batched inference logic + for a population of agents. + - BatchEnv: Manages environment step and reset logic, using a + populaton of agents. + - RolloutStorage: Manages the storing and sampling of collected txns. + - PPO: Handles PPO updates, which take a train state + batch of txns. + """ + + def __init__( + self, + env_name, + env_kwargs, + student_agents, + n_students=1, + n_parallel=1, + n_eval=1, + n_rollout_steps=256, + lr=1e-4, + lr_final=None, + lr_anneal_steps=0, + max_grad_norm=0.5, + discount=0.99, + gae_lambda=0.95, + adam_eps=1e-5, + normalize_return=False, + track_env_metrics=False, + n_unroll_rollout=1, + n_devices=1, + render=False): + + assert len(student_agents) == 1, 'Only one type of student supported.' + assert n_parallel % n_devices == 0, 'Num envs must be divisible by num devices.' + + self.n_students = n_students + self.n_parallel = n_parallel // n_devices + self.n_eval = n_eval + self.n_devices = n_devices + self.step_batch_size = n_students*n_eval*n_parallel + self.n_rollout_steps = n_rollout_steps + self.n_updates = 0 + self.lr = lr + self.lr_final = lr if lr_final is None else lr_final + self.lr_anneal_steps = lr_anneal_steps + self.max_grad_norm = max_grad_norm + self.adam_eps = adam_eps + self.normalize_return = normalize_return + self.track_env_metrics = track_env_metrics + self.n_unroll_rollout = n_unroll_rollout + self.render = render + + self.student_pop = AgentPop(student_agents[0], n_agents=n_students) + + self.env, self.env_params = envs.make( + env_name, + env_kwargs=env_kwargs + ) + self._action_shape = self.env.action_space().shape + + self.benv = envs.BatchEnv( + env_name=env_name, + n_parallel=self.n_parallel, + n_eval=self.n_eval, + env_kwargs=env_kwargs, + wrappers=['monitor_return', 'monitor_ep_metrics'] + ) + self.action_dtype = self.benv.env.action_space().dtype + + self.student_rollout = RolloutStorage( + discount=discount, + gae_lambda=gae_lambda, + n_steps=n_rollout_steps, + n_agents=n_students, + n_envs=self.n_parallel, + n_eval=self.n_eval, + action_space=self.env.action_space(), + obs_space=self.env.observation_space(), + agent=self.student_pop.agent, + ) + + monitored_metrics = self.benv.env.get_monitored_metrics() + self.rolling_stats = RollingStats( + names=monitored_metrics, + window=10, + ) + self._update_ep_stats = jax.vmap( + jax.vmap(self.rolling_stats.update_stats)) + + if self.render: + from envs.viz.grid_viz import GridVisualizer + self.viz = GridVisualizer() + self.viz.show() + + def reset(self, rng): + self.n_updates = 0 + + n_parallel = self.n_parallel*self.n_devices + + rngs, *vrngs = jax.random.split(rng, self.n_students+1) + obs, state, extra = self.benv.reset( + jnp.array(vrngs), n_parallel=n_parallel) + dummy_obs = jax.tree_util.tree_map( + lambda x: x[0], obs) # for one agent only + + rng, subrng = jax.random.split(rng) + if self.student_pop.agent.is_recurrent: + carry = self.student_pop.init_carry(subrng, obs) + self.zero_carry = jax.tree_map( + lambda x: x.at[:, :self.n_parallel].get(), carry) + else: + carry = None + + rng, subrng = jax.random.split(rng) + params = self.student_pop.init_params(subrng, dummy_obs) + + schedule_fn = optax.linear_schedule( + init_value=-float(self.lr), + end_value=-float(self.lr_final), + transition_steps=self.lr_anneal_steps, + ) + + tx = optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + optax.adam(learning_rate=float(self.lr), eps=self.adam_eps) + ) + + train_state = VmapTrainState.create( + apply_fn=self.student_pop.agent.evaluate, + params=params, + tx=tx + ) + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.n_students, n_parallel*self.n_eval)) + + start_state = state + + return ( + rng, + train_state, + state, + start_state, # Used to track metrics from starting state + obs, + carry, + extra, + ep_stats + ) + + def get_checkpoint_state(self, state): + _state = list(state) + _state[1] = state[1].state_dict + + return _state + + def load_checkpoint_state(self, runner_state, state): + runner_state = list(runner_state) + runner_state[1] = runner_state[1].load_state_dict(state[1]) + + return tuple(runner_state) + + @partial(jax.jit, static_argnums=(0, 2)) + def _get_transition( + self, + rng, + pop, + params, + rollout, + state, + start_state, + obs, + carry, + done, + extra=None): + # Sample action + value, pi_params, next_carry = pop.act(params, obs, carry, done) + + pi = pop.get_action_dist(pi_params, dtype=self.action_dtype) + rng, subrng = jax.random.split(rng) + action = pi.sample(seed=subrng) + log_pi = pi.log_prob(action) + + rng, *vrngs = jax.random.split(rng, self.n_students+1) + (next_obs, + next_state, + reward, + done, + info, + extra) = self.benv.step(jnp.array(vrngs), state, action, extra) + + next_start_state = jax.vmap(_tree_util.pytree_select)( + done, next_state, start_state + ) + + # Add transition to storage + step = (obs, action, reward, done, log_pi, value) + if carry is not None: + step += (carry,) + + rollout = self.student_rollout.append(rollout, *step) + + if self.render: + self.viz.render( + self.benv.env.params, + jax.tree_util.tree_map(lambda x: x[0][0], state)) + + return ( + rollout, + next_state, + next_start_state, + next_obs, + next_carry, + done, + info, + extra + ) + + @partial(jax.jit, static_argnums=(0,)) + def _rollout_students( + self, + rng, + train_state, + state, + start_state, + obs, + carry, + done, + extra=None, + ep_stats=None): + rollout = self.student_rollout.reset() + + rngs = jax.random.split(rng, self.n_rollout_steps) + + def _scan_rollout(scan_carry, rng): + rollout, state, start_state, obs, carry, done, extra, ep_stats, train_state = scan_carry + + next_scan_carry = \ + self._get_transition( + rng, + self.student_pop, + jax.lax.stop_gradient(train_state.params), + rollout, + state, + start_state, + obs, + carry, + done, + extra) + (rollout, + next_state, + next_start_state, + next_obs, + next_carry, + done, + info, + extra) = next_scan_carry + + ep_stats = self._update_ep_stats(ep_stats, done, info) + + return ( + rollout, + next_state, + next_start_state, + next_obs, + next_carry, + done, + extra, + ep_stats, + train_state), None + + (rollout, + state, + start_state, + obs, + carry, + done, + extra, + ep_stats, + train_state), _ = jax.lax.scan( + _scan_rollout, + (rollout, + state, + start_state, + obs, + carry, + done, + extra, + ep_stats, + train_state), + rngs, + length=self.n_rollout_steps, + ) + + return rollout, state, start_state, obs, carry, extra, ep_stats, train_state + + @partial(jax.jit, static_argnums=(0,)) + def _compile_stats(self, update_stats, ep_stats, env_metrics=None): + stats = jax.vmap(lambda info: jax.tree_map(lambda x: x.mean(), info))( + {k: ep_stats[k] for k in self.rolling_stats.names} + ) + stats.update(update_stats) + + if self.n_students > 1: + _stats = {} + for i in range(self.n_students): + _student_stats = jax.tree_util.tree_map( + lambda x: x[i], stats) # for agent0 + _stats.update( + {f'a{i}/{k}': v for k, v in _student_stats.items()}) + stats = _stats + + if self.track_env_metrics: + mean_env_metrics = jax.vmap(lambda info: jax.tree_map( + lambda x: x.mean(), info))(env_metrics) + mean_env_metrics = {f'env/{k}': v for k, + v in mean_env_metrics.items()} + + if self.n_students > 1: + _env_metrics = {} + for i in range(self.n_students): + _student_env_metrics = jax.tree_util.tree_map( + lambda x: x[i], mean_env_metrics) # for agent0 + _env_metrics.update( + {f'{k}_a{i}': v for k, v in _student_env_metrics.items()}) + mean_env_metrics = _env_metrics + + stats.update(mean_env_metrics) + + if self.n_students == 1: + stats = jax.tree_map(lambda x: x[0], stats) + + if self.n_devices > 1: + stats = jax.tree_map(lambda x: jax.lax.pmean(x, 'device'), stats) + + return stats + + def get_shmap_spec(self): + runner_state_size = len(inspect.signature(self.run).parameters) + in_spec = [P(None, 'device'),]*(runner_state_size) + out_spec = [P(None, 'device'),]*(runner_state_size) + + in_spec[:2] = [P(None),]*2 + in_spec = tuple(in_spec) + out_spec = (P(None),) + in_spec + + return in_spec, out_spec + + @partial(jax.jit, static_argnums=(0,)) + def run( + self, + rng, + train_state, + state, + start_state, + obs, + carry=None, + extra=None, + ep_stats=None): + """ + Perform one update step: rollout all students and teachers + update with PPO + """ + if self.n_devices > 1: + rng = jax.random.fold_in(rng, jax.lax.axis_index('device')) + + rng, *vrngs = jax.random.split(rng, self.n_students+1) + rollout_batch_shape = (self.n_students, self.n_parallel*self.n_eval) + + obs, state, extra = self.benv.reset(jnp.array(vrngs)) + ep_stats = self.rolling_stats.reset_stats( + batch_shape=rollout_batch_shape) + + rollout_start_state = state + + done = jnp.zeros(rollout_batch_shape, dtype=jnp.bool_) + rng, subrng = jax.random.split(rng) + rollout, state, start_state, obs, carry, extra, ep_stats, train_state = \ + self._rollout_students( + subrng, + train_state, + state, + start_state, + obs, + carry, + done, + extra, + ep_stats + ) + + train_batch = self.student_rollout.get_batch( + rollout, + self.student_pop.get_value( + jax.lax.stop_gradient(train_state.params), + obs, + carry, + ) + ) + + # PPOAgent vmaps over the train state and batch. Batch must be N x EM + rng, subrng = jax.random.split(rng) + train_state, update_stats = self.student_pop.update( + subrng, train_state, train_batch) + + # Collect env metrics + if self.track_env_metrics: + env_metrics = self.benv.get_env_metrics(rollout_start_state) + else: + env_metrics = None + + stats = self._compile_stats(update_stats, ep_stats, env_metrics) + stats.update(dict(n_updates=train_state.n_updates[0])) + + train_state = train_state.increment() + self.n_updates += 1 + + return ( + stats, + rng, + train_state, + state, + start_state, + obs, + carry, + extra, + ep_stats + ) diff --git a/src/minimax/runners/eval_runner.py b/src/minimax/runners/eval_runner.py new file mode 100644 index 0000000..7b53ded --- /dev/null +++ b/src/minimax/runners/eval_runner.py @@ -0,0 +1,325 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from typing import Tuple, Optional + +import numpy as np +import jax +import jax.numpy as jnp + +import minimax.envs as envs +from minimax.util.rl import ( + AgentPop, + RollingStats +) +import minimax.util.pytree as _tree_util + + +def generate_all_kwargs_combos(arg_choices): + def update_kwargs_with_choices(prev_combos, k, choices): + updated_combos = [] + for v in choices: + for p in prev_combos: + updated = p.copy() + updated[k] = v + updated_combos.append(updated) + + return updated_combos + + all_combos = [{}] + for k, choices in arg_choices.items(): + all_combos = update_kwargs_with_choices(all_combos, k, choices) + + return all_combos + + +def create_envs_for_kwargs(env_names, kwargs): + # Check for csv kwargs + arg_choices = {} + varied_args = [] + for k, v in kwargs.items(): + if isinstance(v, str) and ',' in v: + vs = eval(v) + arg_choices[k] = vs + varied_args.append(k) + elif isinstance(v, str): + arg_choices[k] = [eval(v)] + else: + arg_choices[k] = [v] + + # List of kwargs + kwargs_combos = generate_all_kwargs_combos(arg_choices) + + env_infos = [] + incl_ext = len(varied_args) > 0 + for name in env_names: + for kwargs in kwargs_combos: + if incl_ext and len(kwargs) > 0: + ext = ':'.join([f'{k}={kwargs[k]}' for k in varied_args]) + ext_name = f'{name}:{ext}' + else: + ext_name = name + env_infos.append( + (name, ext_name, kwargs) + ) + + return env_infos + + +class EvalRunner: + def __init__( + self, + pop, + env_names, + env_kwargs=None, + n_episodes=10, + agent_idxs='*', + render_mode=None): + + self.pop = pop + + if isinstance(agent_idxs, str): + if "*" in agent_idxs: + self.agent_idxs = np.arange(pop.n_agents) + else: + self.agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + else: + self.agent_idxs = agent_idxs # assume array + + assert np.max(self.agent_idxs) < pop.n_agents, \ + 'Agent index is out of bounds.' + + if isinstance(env_names, str): + env_names = [ + x.strip() for x in env_names.split(',') + ] + + self.n_episodes = n_episodes + env_infos = create_envs_for_kwargs(env_names, env_kwargs) + env_names = [] + self.ext_env_names = [] + env_kwargs = [] + for (name, ext_name, kwargs) in env_infos: + env_names.append(name) + self.ext_env_names.append(ext_name) + env_kwargs.append(kwargs) + self.n_envs = len(env_names) + + self.benvs = [] + self.env_params = [] + self.env_has_solved_rate = [] + for env_name, kwargs in zip(env_names, env_kwargs): + benv = envs.BatchEnv( + env_name=env_name, + n_parallel=n_episodes, + n_eval=1, + env_kwargs=kwargs, + wrappers=['monitor_return', 'monitor_ep_metrics'] + ) + self.benvs.append(benv) + self.env_params.append(benv.env.params) + self.env_has_solved_rate.append( + benv.env.eval_solved_rate is not None) + + self.action_dtype = self.benvs[0].env.action_space().dtype + + monitored_metrics = self.benvs[0].env.get_monitored_metrics() + self.rolling_stats = RollingStats(names=monitored_metrics, window=1) + self._update_ep_stats = jax.vmap( + jax.vmap( + self.rolling_stats.update_stats, in_axes=(0, 0, 0, None)), + in_axes=(0, 0, 0, None)) + + self.test_return_pre = 'test_return' + self.test_solved_rate_pre = 'test_solved_rate' + + self.render_mode = render_mode + if render_mode: + from minimax.envs.viz.grid_viz import GridVisualizer + self.viz = GridVisualizer() + self.viz.show() + + if render_mode == 'ipython': + from IPython import display + self.ipython_display = display + + def load_checkpoint_state(self, runner_state, state): + runner_state = list(runner_state) + runner_state[1] = runner_state[1].load_state_dict(state[1]) + + return tuple(runner_state) + + @partial(jax.jit, static_argnums=(0, 2)) + def _get_transition( + self, + rng, + benv, + params, + state, + obs, + carry, + zero_carry, + extra): + value, pi_params, next_carry = self.pop.act(params, obs, carry) + pi = self.pop.get_action_dist(pi_params, dtype=self.action_dtype) + rng, subrng = jax.random.split(rng) + action = pi.sample(seed=subrng) + log_pi = pi.log_prob(action) + + rng, *vrngs = jax.random.split(rng, self.pop.n_agents+1) + + step_args = (jnp.array(vrngs), state, action, extra) + (next_obs, + next_state, + reward, + done, + info, + extra) = benv.step(*step_args) + + # Add transition to storage + step = (obs, action, reward, done, log_pi, value) + if carry is not None: + step += (carry,) + + # Zero carry if needed + if carry is not None: + next_carry = jax.vmap(_tree_util.pytree_select)( + done, zero_carry, next_carry) + + if self.render_mode: + self.viz.render( + benv.env.params, + jax.tree_util.tree_map(lambda x: x[0][0], state)) + if self.render_mode == 'ipython': + self.ipython_display.display(self.viz.window.fig) + self.ipython_display.clear_output(wait=True) + + return next_state, next_obs, next_carry, done, info, extra + + @partial(jax.jit, static_argnums=(0, 2)) + def _rollout_benv( + self, + rng, + benv, + params, + env_params, + state, + obs, + carry, + zero_carry, + extra, + ep_stats): + + def _scan_rollout(scan_carry, rng): + (state, + obs, + carry, + extra, + ep_stats) = scan_carry + + step = \ + self._get_transition( + rng, + benv, + params, + state, + obs, + carry, + zero_carry, + extra) + + (next_state, + next_obs, + next_carry, + done, + info, + extra) = step + + ep_stats = self._update_ep_stats(ep_stats, done, info, 1) + + return (next_state, next_obs, next_carry, extra, ep_stats), None + + n_steps = benv.env.max_episode_steps() + rngs = jax.random.split(rng, n_steps) + (state, + obs, + carry, + extra, + ep_stats), _ = jax.lax.scan( + _scan_rollout, + (state, obs, carry, extra, ep_stats), + rngs, + length=n_steps) + + return ep_stats + + @partial(jax.jit, static_argnums=(0,)) + def run(self, rng, params): + """ + Rollout agents on each env. + + For each env, run n_eval episodes in parallel, + where each is indexed to return in order. + """ + eval_stats = self.fake_run(rng, params) + rng, *rollout_rngs = jax.random.split(rng, self.n_envs+1) + for i, (benv, env_param) in enumerate(zip(self.benvs, self.env_params)): + rng, *reset_rngs = jax.random.split(rng, self.pop.n_agents+1) + obs, state, extra = benv.reset(jnp.array(reset_rngs)) + + if self.pop.agent.is_recurrent: + rng, subrng = jax.random.split(rng) + zero_carry = self.pop.init_carry(subrng, obs) + else: + zero_carry = None + + # Reset episodic stats + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.pop.n_agents, self.n_episodes)) + + ep_stats = self._rollout_benv( + rollout_rngs[i], + benv, + jax.lax.stop_gradient(params), + env_param, + state, + obs, + zero_carry, + zero_carry, + extra, + ep_stats) + + env_name = self.ext_env_names[i] + mean_return = ep_stats['return'].mean(1) + + if self.env_has_solved_rate[i]: + mean_solved_rate = jax.vmap( + jax.vmap(benv.env.eval_solved_rate))(ep_stats).mean(1) + + for idx in self.agent_idxs: + eval_stats[f'eval/a{idx}:{self.test_return_pre}:{env_name}'] = mean_return[idx].squeeze() + if self.env_has_solved_rate[i]: + eval_stats[f'eval/a{idx}:{self.test_solved_rate_pre}:{env_name}'] = mean_solved_rate[idx].squeeze() + + return eval_stats + + def fake_run(self, rng, params): + eval_stats = {} + for i, env_name in enumerate(self.ext_env_names): + for idx in self.agent_idxs: + eval_stats.update({ + f'eval/a{idx}:{self.test_return_pre}:{env_name}': 0. + }) + if self.env_has_solved_rate[i]: + eval_stats.update({ + f'eval/a{idx}:{self.test_solved_rate_pre}:{env_name}': 0., + }) + + return eval_stats diff --git a/src/minimax/runners/paired_runner.py b/src/minimax/runners/paired_runner.py new file mode 100644 index 0000000..5b73fd3 --- /dev/null +++ b/src/minimax/runners/paired_runner.py @@ -0,0 +1,604 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from enum import Enum +from functools import partial +from typing import Tuple, Optional +import inspect + +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +import optax +import flax +import flax.linen as nn +from flax.core.frozen_dict import FrozenDict + +import minimax.envs as envs +from minimax.util import pytree as _tree_util +from minimax.util.rl import ( + AgentPop, + VmapTrainState, + RolloutStorage, + RollingStats, + UEDScore, + compute_ued_scores +) + + +class PAIREDRunner: + """ + Orchestrates rollouts across one or more students and teachers. + The main components at play: + - AgentPop: Manages train state and batched inference logic + for a population of agents. + - BatchUEDEnv: Manages environment step and reset logic for a + population of agents batched over a pair of student and + teacher MDPs. + - RolloutStorage: Manages the storing and sampling of collected txns. + - PPO: Handles PPO updates, which take a train state + batch of txns. + """ + + def __init__( + self, + env_name, + env_kwargs, + ued_env_kwargs, + student_agents, + n_students=2, + n_parallel=1, + n_eval=1, + n_rollout_steps=250, + lr=1e-4, + lr_final=None, + lr_anneal_steps=0, + max_grad_norm=0.5, + discount=0.99, + gae_lambda=0.95, + adam_eps=1e-5, + teacher_lr=None, + teacher_lr_final=None, + teacher_lr_anneal_steps=None, + teacher_discount=0.99, + teacher_gae_lambda=0.95, + teacher_agents=None, + ued_score='relative_regret', + track_env_metrics=False, + n_unroll_rollout=1, + render=False, + n_devices=1): + assert n_parallel % n_devices == 0, 'Num envs must be divisible by num devices.' + + ued_score = UEDScore[ued_score.upper()] + + assert len(student_agents) == 1, \ + 'Only one type of student supported.' + assert not (n_students > 2 and ued_score in [UEDScore.RELATIVE_REGRET, UEDScore.MEAN_RELATIVE_REGRET]), \ + 'Standard PAIRED uses only 2 students.' + assert teacher_agents is None or len(teacher_agents) == 1, \ + 'Only one type of teacher supported.' + + self.n_students = n_students + self.n_parallel = n_parallel // n_devices + self.n_eval = n_eval + self.n_devices = n_devices + self.step_batch_size = n_students*n_eval*n_parallel + self.n_rollout_steps = n_rollout_steps + self.n_updates = 0 + self.lr = lr + self.lr_final = lr if lr_final is None else lr_final + self.lr_anneal_steps = lr_anneal_steps + self.teacher_lr = \ + lr if teacher_lr is None else lr + self.teacher_lr_final = \ + self.lr_final if teacher_lr_final is None else teacher_lr_final + self.teacher_lr_anneal_steps = \ + lr_anneal_steps if teacher_lr_anneal_steps is None else teacher_lr_anneal_steps + self.max_grad_norm = max_grad_norm + self.adam_eps = adam_eps + self.ued_score = ued_score + self.track_env_metrics = track_env_metrics + + self.n_unroll_rollout = n_unroll_rollout + self.render = render + + self.student_pop = AgentPop(student_agents[0], n_agents=n_students) + + if teacher_agents is not None: + self.teacher_pop = AgentPop(teacher_agents[0], n_agents=1) + + # This ensures correct partial-episodic bootstrapping by avoiding + # any termination purely due to timeouts. + # env_kwargs.max_episode_steps = self.n_rollout_steps + 1 + self.benv = envs.BatchUEDEnv( + env_name=env_name, + n_parallel=self.n_parallel, + n_eval=n_eval, + env_kwargs=env_kwargs, + ued_env_kwargs=ued_env_kwargs, + wrappers=['monitor_return', 'monitor_ep_metrics'], + ued_wrappers=[] + ) + self.teacher_n_rollout_steps = \ + self.benv.env.ued_max_episode_steps() + + self.student_rollout = RolloutStorage( + discount=discount, + gae_lambda=gae_lambda, + n_steps=n_rollout_steps, + n_agents=n_students, + n_envs=self.n_parallel, + n_eval=self.n_eval, + action_space=self.benv.env.action_space(), + obs_space=self.benv.env.observation_space(), + agent=self.student_pop.agent + ) + + self.teacher_rollout = RolloutStorage( + discount=teacher_discount, + gae_lambda=teacher_gae_lambda, + n_steps=self.teacher_n_rollout_steps, + n_agents=1, + n_envs=self.n_parallel, + n_eval=1, + action_space=self.benv.env.ued_action_space(), + obs_space=self.benv.env.ued_observation_space(), + agent=self.teacher_pop.agent, + ) + + ued_monitored_metrics = ('return',) + self.ued_rolling_stats = RollingStats( + names=ued_monitored_metrics, + window=10, + ) + + monitored_metrics = self.benv.env.get_monitored_metrics() + self.rolling_stats = RollingStats( + names=monitored_metrics, + window=10, + ) + + self._update_ep_stats = jax.vmap( + jax.vmap(self.rolling_stats.update_stats)) + self._update_ued_ep_stats = jax.vmap( + jax.vmap(self.ued_rolling_stats.update_stats)) + + if self.render: + from envs.viz.grid_viz import GridVisualizer + self.viz = GridVisualizer() + self.viz.show() + + def reset(self, rng): + self.n_updates = 0 + + n_parallel = self.n_parallel*self.n_devices + + rng, student_rng, teacher_rng = jax.random.split(rng, 3) + student_info = self._reset_pop( + student_rng, + self.student_pop, + partial(self.benv.reset, sub_batch_size=n_parallel*self.n_eval), + n_parallel_ep=n_parallel*self.n_eval, + lr_init=self.lr, + lr_final=self.lr_final, + lr_anneal_steps=self.lr_anneal_steps) + + teacher_info = self._reset_pop( + teacher_rng, + self.teacher_pop, + partial(self.benv.reset_teacher, n_parallel=n_parallel), + n_parallel_ep=n_parallel, + lr_init=self.teacher_lr, + lr_final=self.teacher_lr_final, + lr_anneal_steps=self.teacher_lr_anneal_steps) + + return ( + rng, + *student_info, + *teacher_info + ) + + def _reset_pop( + self, + rng, + pop, + env_reset_fn, + n_parallel_ep=1, + lr_init=3e-4, + lr_final=3e-4, + lr_anneal_steps=0): + rng, *vrngs = jax.random.split(rng, pop.n_agents+1) + reset_out = env_reset_fn(jnp.array(vrngs)) + if len(reset_out) == 2: + obs, state = reset_out + else: + obs, state, extra = reset_out + dummy_obs = jax.tree_util.tree_map( + lambda x: x[0], obs) # for one agent only + + rng, subrng = jax.random.split(rng) + if pop.agent.is_recurrent: + carry = pop.init_carry(subrng, obs) + else: + carry = None + + rng, subrng = jax.random.split(rng) + params = pop.init_params(subrng, dummy_obs) + + schedule_fn = optax.linear_schedule( + init_value=-float(lr_init), + end_value=-float(lr_final), + transition_steps=lr_anneal_steps, + ) + + tx = optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + optax.scale_by_adam(eps=self.adam_eps), + optax.scale_by_schedule(schedule_fn), + ) + + train_state = VmapTrainState.create( + apply_fn=pop.agent.evaluate, + params=params, + tx=tx + ) + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(pop.n_agents, n_parallel_ep)) + + return train_state, state, obs, carry, ep_stats + + def get_checkpoint_state(self, state): + _state = list(state) + _state[1] = state[1].state_dict + _state[6] = state[6].state_dict + + return _state + + def load_checkpoint_state(self, runner_state, state): + runner_state = list(runner_state) + runner_state[1] = runner_state[1].load_state_dict(state[1]) + runner_state[6] = runner_state[6].load_state_dict(state[6]) + + return tuple(runner_state) + + @partial(jax.jit, static_argnums=(0, 2, 3)) + def _get_transition( + self, + rng, + pop, + rollout_mgr, + rollout, + params, + state, + obs, + carry, + done, + reset_state=None, + extra=None): + # Sample action + value, pi_params, next_carry = pop.act(params, obs, carry, done) + pi = pop.get_action_dist(pi_params) + rng, subrng = jax.random.split(rng) + action = pi.sample(seed=subrng) + log_pi = pi.log_prob(action) + + rng, *vrngs = jax.random.split(rng, pop.n_agents+1) + + if pop is self.student_pop: + step_fn = self.benv.step_student + else: + step_fn = self.benv.step_teacher + step_args = (jnp.array(vrngs), state, action) + + if reset_state is not None: # Needed for student to reset to same instance + step_args += (reset_state,) + + if extra is not None: + step_args += (extra,) + next_obs, next_state, reward, done, info, extra = step_fn( + *step_args) + else: + next_obs, next_state, reward, done, info = step_fn(*step_args) + + # Add transition to storage + step = (obs, action, reward, done, log_pi, value) + if carry is not None: + step += (carry,) + + rollout = rollout_mgr.append(rollout, *step) + + if self.render and pop is self.student_pop: + self.viz.render( + self.benv.env.env.params, + jax.tree_util.tree_map(lambda x: x[0][0], state)) + + return rollout, next_state, next_obs, next_carry, done, info, extra + + @partial(jax.jit, static_argnums=(0, 2, 3, 4)) + def _rollout( + self, + rng, + pop, + rollout_mgr, + n_steps, + params, + state, + obs, + carry, + done, + reset_state=None, + extra=None, + ep_stats=None): + rngs = jax.random.split(rng, n_steps) + + rollout = rollout_mgr.reset() + + def _scan_rollout(scan_carry, rng): + (rollout, + state, + obs, + carry, + done, + extra, + ep_stats) = scan_carry + + next_scan_carry = \ + self._get_transition( + rng, + pop, + rollout_mgr, + rollout, + params, + state, + obs, + carry, + done, + reset_state, + extra) + + (rollout, + next_state, + next_obs, + next_carry, + done, + info, + extra) = next_scan_carry + + if ep_stats is not None: + _ep_stats_update_fn = self._update_ep_stats \ + if pop is self.student_pop else self._update_ued_ep_stats + + ep_stats = _ep_stats_update_fn(ep_stats, done, info) + + return (rollout, next_state, next_obs, next_carry, done, extra, ep_stats), None + + (rollout, state, obs, carry, done, extra, ep_stats), _ = jax.lax.scan( + _scan_rollout, + (rollout, state, obs, carry, done, extra, ep_stats), + rngs, + length=n_steps, + unroll=self.n_unroll_rollout + ) + + return rollout, state, obs, carry, extra, ep_stats + + @partial(jax.jit, static_argnums=(0,)) + def _compile_stats(self, + update_stats, ep_stats, + ued_update_stats, ued_ep_stats, + env_metrics=None, + grad_stats=None, ued_grad_stats=None): + mean_returns_by_student = jax.vmap( + lambda x: x.mean())(ep_stats['return']) + mean_returns_by_teacher = jax.vmap( + lambda x: x.mean())(ued_ep_stats['return']) + + mean_ep_stats = jax.vmap(lambda info: jax.tree_map(lambda x: x.mean(), info))( + {k: ep_stats[k] for k in self.rolling_stats.names} + ) + ued_mean_ep_stats = jax.vmap(lambda info: jax.tree_map(lambda x: x.mean(), info))( + {k: ued_ep_stats[k] for k in self.ued_rolling_stats.names} + ) + + student_stats = { + f'mean_{k}': v for k, v in mean_ep_stats.items() + } + student_stats.update(update_stats) + + stats = {} + for i in range(self.n_students): + _student_stats = jax.tree_util.tree_map( + lambda x: x[i], student_stats) # for agent0 + stats.update({f'{k}_a{i}': v for k, v in _student_stats.items()}) + + teacher_stats = { + f'mean_{k}_tch': v for k, v in ued_mean_ep_stats.items() + } + teacher_stats.update({ + f'{k}_tch': v[0] for k, v in ued_update_stats.items() + }) + stats.update(teacher_stats) + + if self.track_env_metrics: + passable_mask = env_metrics.pop('passable') + mean_env_metrics = jax.tree_util.tree_map( + lambda x: (x*passable_mask).sum()/passable_mask.sum(), + env_metrics + ) + mean_env_metrics.update({'passable_ratio': passable_mask.mean()}) + stats.update({ + f'env/{k}': v for k, v in mean_env_metrics.items() + }) + + if self.n_devices > 1: + stats = jax.tree_map(lambda x: jax.lax.pmean(x, 'device'), stats) + + return stats + + def get_shmap_spec(self): + runner_state_size = len(inspect.signature(self.run).parameters) + in_spec = [P(None, 'device'),]*(runner_state_size) + out_spec = [P(None, 'device'),]*(runner_state_size) + + in_spec[:2] = [P(None),]*2 + in_spec[6] = P(None) + in_spec = tuple(in_spec) + out_spec = (P(None),) + in_spec + + return in_spec, out_spec + + @partial(jax.jit, static_argnums=(0,)) + def run( + self, + rng, + train_state, + state, + obs, + carry, + ep_stats, + ued_train_state, + ued_state, + ued_obs, + ued_carry, + ued_ep_stats): + """ + Perform one update step: rollout teacher + students + """ + if self.n_devices > 1: + rng = jax.random.fold_in(rng, jax.lax.axis_index('device')) + + # === Reset teacher env + rollout teacher + rng, *vrngs = jax.random.split(rng, self.teacher_pop.n_agents+1) + ued_reset_out = self.benv.reset_teacher(jnp.array(vrngs)) + if len(ued_reset_out) > 2: + ued_obs, ued_state, ued_extra = ued_reset_out + else: + ued_obs, ued_state = ued_reset_out + ued_extra = None + + # Reset UED ep_stats + if self.ued_rolling_stats is not None: + ued_ep_stats = self.ued_rolling_stats.reset_stats( + batch_shape=(1, self.n_parallel)) + else: + ued_ep_stats = None + + tch_rollout_batch_shape = (1, self.n_parallel*self.n_eval) + done = jnp.zeros(tch_rollout_batch_shape, dtype=jnp.bool_) + rng, subrng = jax.random.split(rng) + ued_rollout, ued_state, ued_obs, ued_carry, _, ued_ep_stats = \ + self._rollout( + subrng, + self.teacher_pop, + self.teacher_rollout, + self.teacher_n_rollout_steps, + jax.lax.stop_gradient(ued_train_state.params), + ued_state, + ued_obs, + ued_carry, + done, + extra=ued_extra, + ep_stats=ued_ep_stats + ) + + # === Reset student to new envs + rollout students + rng, *vrngs = jax.random.split(rng, self.teacher_pop.n_agents+1) + obs, state, extra = jax.tree_util.tree_map( + lambda x: x.squeeze(0), self.benv.reset_student( + jnp.array(vrngs), + ued_state, + self.student_pop.n_agents)) + reset_state = state + + # Reset student ep_stats + st_rollout_batch_shape = (self.n_students, self.n_parallel*self.n_eval) + ep_stats = self.rolling_stats.reset_stats( + batch_shape=st_rollout_batch_shape) + + done = jnp.zeros(st_rollout_batch_shape, dtype=jnp.bool_) + rng, subrng = jax.random.split(rng) + rollout, state, obs, carry, extra, ep_stats = \ + self._rollout( + subrng, + self.student_pop, + self.student_rollout, + self.n_rollout_steps, + jax.lax.stop_gradient(train_state.params), + state, + obs, + carry, + done, + reset_state=reset_state, + extra=extra, + ep_stats=ep_stats) + + # === Update student with PPO + # PPOAgent vmaps over the train state and batch. Batch must be N x EM + student_rollout_last_value = self.student_pop.get_value( + jax.lax.stop_gradient(train_state.params), obs, carry + ) + train_batch = self.student_rollout.get_batch( + rollout, + student_rollout_last_value + ) + + rng, subrng = jax.random.split(rng) + train_state, update_stats = self.student_pop.update( + subrng, train_state, train_batch) + + # === Update teacher with PPO + # - Compute returns per env per agent + # - Compute batched returns based on returns per env per agent + ued_score, _ = compute_ued_scores( + self.ued_score, train_batch, self.n_eval) + ued_rollout = self.teacher_rollout.set_final_reward( + ued_rollout, ued_score) + ued_train_batch = self.teacher_rollout.get_batch( + ued_rollout, + jnp.zeros((1, self.n_parallel)) # Last step terminates episode + ) + + ued_ep_stats = self._update_ued_ep_stats( + ued_ep_stats, + jnp.ones((1, len(ued_score), 1), dtype=jnp.bool_), + {'return': jnp.expand_dims(ued_score, (0, -1))} + ) + + # Update teacher, batch must be 1 x Ex1 + rng, subrng = jax.random.split(rng) + ued_train_state, ued_update_stats = self.teacher_pop.update( + subrng, ued_train_state, ued_train_batch) + + # -------------------------------------------------- + # Collect metrics + if self.track_env_metrics: + env_metrics = self.benv.get_env_metrics(reset_state) + else: + env_metrics = None + + grad_stats, ued_grad_stats = None, None + + stats = self._compile_stats( + update_stats, ep_stats, + ued_update_stats, ued_ep_stats, + env_metrics, + grad_stats, ued_grad_stats) + stats.update(dict(n_updates=train_state.n_updates[0])) + + train_state = train_state.increment() + ued_train_state = ued_train_state.increment() + self.n_updates += 1 + + return ( + stats, + rng, + train_state, state, obs, carry, ep_stats, + ued_train_state, ued_state, ued_obs, ued_carry, ued_ep_stats + ) diff --git a/src/minimax/runners/plr_runner.py b/src/minimax/runners/plr_runner.py new file mode 100644 index 0000000..14a98f2 --- /dev/null +++ b/src/minimax/runners/plr_runner.py @@ -0,0 +1,549 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from enum import Enum +from typing import Tuple, Optional + +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +import optax +import flax +import flax.linen as nn +from flax.core.frozen_dict import FrozenDict + +import minimax.envs as envs +from minimax.runners.dr_runner import DRRunner +from minimax.util import pytree as _tree_util +from minimax.util.rl import ( + AgentPop, + VmapTrainState, + RolloutStorage, + RollingStats, + UEDScore, + compute_ued_scores, + PopPLRManager +) + + +class MutationCriterion(Enum): + BATCH = 0 + EASY = 1 + HARD = 2 + + +class PLRRunner(DRRunner): + def __init__( + self, + *, + replay_prob=0.5, + buffer_size=100, + staleness_coef=0.3, + use_score_ranks=True, + temp=1.0, + min_fill_ratio=0.5, + use_robust_plr=False, + use_parallel_eval=False, + ued_score='l1_value_loss', + force_unique=False, # Slower if True + mutation_fn=None, + n_mutations=0, + mutation_criterion='batch', + mutation_subsample_size=1, + **kwargs): + use_mutations = mutation_fn is not None + if use_parallel_eval: + replay_prob = 1.0 # Replay every rollout cycle + # Force batch mutations (no UED scores) + mutation_criterion = 'batch' + self._n_parallel_batches = 3 if use_mutations else 2 + kwargs['n_parallel'] *= self._n_parallel_batches + + super().__init__(**kwargs) + + self.replay_prob = replay_prob + self.buffer_size = buffer_size + self.staleness_coef = staleness_coef + self.temp = temp + self.use_score_ranks = use_score_ranks + self.min_fill_ratio = min_fill_ratio + self.use_robust_plr = use_robust_plr + self.use_parallel_eval = use_parallel_eval + self.ued_score = UEDScore[ued_score.upper()] + + self.use_mutations = use_mutations + if self.use_mutations: + self.mutation_fn = envs.get_mutator( + self.benv.env_name, mutation_fn) + else: + self.mutation_fn = None + self.n_mutations = n_mutations + self.mutation_criterion = MutationCriterion[mutation_criterion.upper()] + self.mutation_subsample_size = mutation_subsample_size + + self.force_unique = force_unique + if force_unique: + self.comparator_fn = envs.get_comparator(self.benv.env_name) + else: + self.comparator_fn = None + + if mutation_fn is not None and mutation_criterion != 'batch': + assert self.n_parallel % self.mutation_subsample_size == 0, \ + 'Number of parallel envs must be divisible by mutation subsample size.' + + def reset(self, rng): + runner_state = list(super().reset(rng)) + rng = runner_state[0] + runner_state[0], subrng = jax.random.split(rng) + example_state = self.benv.env.reset(rng)[1] + + self.plr_mgr = PopPLRManager( + n_agents=self.n_students, + example_level=example_state, + ued_score=self.ued_score, + replay_prob=self.replay_prob, + buffer_size=self.buffer_size, + staleness_coef=self.staleness_coef, + temp=self.temp, + use_score_ranks=self.use_score_ranks, + min_fill_ratio=self.min_fill_ratio, + use_robust_plr=self.use_robust_plr, + use_parallel_eval=self.use_parallel_eval, + comparator_fn=self.comparator_fn, + n_devices=self.n_devices + ) + plr_buffer = self.plr_mgr.reset(self.n_students) + + train_state = runner_state[1] + train_state = train_state.replace(plr_buffer=plr_buffer) + if self.n_devices == 1: + runner_state[1] = train_state + else: + plr_buffer = jax.tree_map(lambda x: x.repeat( + self.n_devices, 1), plr_buffer) # replicate plr buffer + # Return PLR buffer directly to make shmap easier + runner_state += (plr_buffer,) + + self.dummy_eval_output = self._create_dummy_eval_output(train_state) + + return tuple(runner_state) + + def _create_dummy_eval_output(self, train_state): + rng, * \ + vrngs = jax.random.split(jax.random.PRNGKey(0), self.n_students+1) + obs, state, extra = self.benv.reset(jnp.array(vrngs)) + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.n_students, self.n_parallel*self.n_eval)) + + ued_scores = jnp.zeros((self.n_students, self.n_parallel)) + + if self.student_pop.agent.is_recurrent: + carry = self.zero_carry + else: + carry = None + rollout = self.student_rollout.reset() + + value, _ = self.student_pop.get_value( + jax.lax.stop_gradient(train_state.params), + obs, + carry, + ) + train_batch = self.student_rollout.get_batch( + rollout, value + ) + + return ( + rng, + train_state, + state, + state, + obs, + carry, + extra, + ep_stats, + state, + train_batch, + ued_scores + ) + + @partial(jax.jit, static_argnums=(0, 8)) + def _eval_and_update_plr( + self, + rng, + levels, + level_idxs, + train_state, + update_plr, + parent_idxs=None, + dupe_mask=None, + fake=False): + # Collect rollout and optionally update plr buffer + # Returns train_batch and ued_scores + if fake: + dummy_eval_output = list(self.dummy_eval_output) + dummy_eval_output[1] = train_state + return tuple(dummy_eval_output) + + rollout_batch_shape = (self.n_students, self.n_parallel*self.n_eval) + obs, state, extra = self.benv.set_state(levels) + ep_stats = self.rolling_stats.reset_stats( + batch_shape=rollout_batch_shape) + + rollout_start_state = state + + done = jnp.zeros(rollout_batch_shape, dtype=jnp.bool_) + if self.student_pop.agent.is_recurrent: + carry = self.zero_carry + else: + carry = None + + rng, subrng = jax.random.split(rng) + start_state = state + rollout, state, start_state, obs, carry, extra, ep_stats, train_state = \ + self._rollout_students( + subrng, + train_state, + state, + start_state, + obs, + carry, + done, + extra, + ep_stats + ) + + value, _ = self.student_pop.get_value( + jax.lax.stop_gradient(train_state.params), + obs, + carry + ) + train_batch = self.student_rollout.get_batch(rollout, value) + + # Update PLR buffer + if self.ued_score == UEDScore.MAX_MC: + max_returns = jax.vmap(lambda x, y: x.at[y].get())( + train_state.plr_buffer.max_returns, level_idxs) + max_returns = jnp.where( + jnp.greater_equal(level_idxs, 0), + max_returns, + jnp.full_like(max_returns, -jnp.inf) + ) + ued_info = {'max_returns': max_returns} + else: + ued_info = None + ued_scores, ued_score_info = compute_ued_scores( + self.ued_score, train_batch, self.n_eval, info=ued_info, ignore_val=-jnp.inf, per_agent=True) + next_plr_buffer = self.plr_mgr.update( + train_state.plr_buffer, + levels=levels, + level_idxs=level_idxs, + ued_scores=ued_scores, + dupe_mask=dupe_mask, + info=ued_score_info, + ignore_val=-jnp.inf, + parent_idxs=parent_idxs) + + next_plr_buffer = jax.vmap( + lambda update, new, prev: jax.tree_map( + lambda x, y: jax.lax.select(update, x, y), new, prev) + )(update_plr, next_plr_buffer, train_state.plr_buffer) + + train_state = train_state.replace(plr_buffer=next_plr_buffer) + + return ( + rng, + train_state, + state, + start_state, + obs, + carry, + extra, + ep_stats, + rollout_start_state, + train_batch, + ued_scores, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _mutate_levels(self, rng, levels, level_idxs, ued_scores=None): + if not self.use_mutations: + return levels, level_idxs, jnp.full_like(level_idxs, -1) + + def upsample_levels(levels, level_idxs, subsample_idxs): + subsample_idxs = subsample_idxs.repeat( + self.n_parallel//self.mutation_subsample_size, -1) + parent_idxs = level_idxs.take(subsample_idxs) + levels = jax.vmap( + lambda x, y: jax.tree_map( + lambda _x: jnp.array(_x).take(y, 0), x) + )(levels, parent_idxs) + + return levels, parent_idxs + + if self.mutation_criterion == MutationCriterion.BATCH: + parent_idxs = level_idxs + + if self.mutation_criterion == MutationCriterion.EASY: + _, top_level_idxs = jax.lax.approx_min_k( + ued_scores, self.mutation_subsample_size) + levels, parent_idxs = upsample_levels( + levels, level_idxs, top_level_idxs) + + elif self.mutation_criterion == MutationCriterion.HARD: + _, top_level_idxs = jax.lax.approx_max_k( + ued_scores, self.mutation_subsample_size) + levels, parent_idxs = upsample_levels( + levels, level_idxs, top_level_idxs) + + n_parallel = level_idxs.shape[-1] + vrngs = jax.vmap(lambda subrng: jax.random.split(subrng, n_parallel))( + jax.random.split(rng, self.n_students) + ) + + mutated_levels = jax.vmap( + lambda *args: jax.vmap(self.mutation_fn, + in_axes=(0, None, 0, None))(*args), + in_axes=(0, None, 0, None) + )(vrngs, self.benv.env_params, levels, self.n_mutations) + + # Mutated levels do not have existing idxs in the PLR buffer. + mutated_level_idxs = jnp.full((self.n_students, n_parallel), -1) + + return mutated_levels, mutated_level_idxs, parent_idxs + + def _efficient_grad_update(self, rng, train_state, train_batch, is_replay): + # PPOAgent vmaps over the train state and batch. Batch must be N x EM + skip_grad_update = jnp.logical_and(self.use_robust_plr, ~is_replay) + + if self.n_students == 1: + train_state, stats = jax.lax.cond( + skip_grad_update[0], + partial(self.student_pop.update, fake=True), + self.student_pop.update, + *(rng, train_state, train_batch) + ) + elif self.n_students > 1: # Have to vmap all students + take only students that need updates + _, dummy_stats = jax.vmap( + lambda *_: self.student_pop.agent.get_empty_update_stats())(np.arange(self.n_students)) + _train_state, stats = self.student.update( + rng, train_state, train_batch) + train_state, stats = jax.vmap(lambda cond, x, y: + jax.tree_map(lambda _cond, _x, _y: jax.lax.select(_cond, _x, _y), cond, x, y))( + is_replay, (train_state, stats), (_train_state, dummy_stats) + ) + + return train_state, stats + + @partial(jax.jit, static_argnums=(0,)) + def _compile_stats(self, update_stats, ep_stats, env_metrics=None, plr_stats=None): + stats = super()._compile_stats(update_stats, ep_stats, env_metrics) + + if plr_stats is not None: + plr_stats = jax.vmap(lambda info: jax.tree_map( + lambda x: x.mean(), info))(plr_stats) + + if self.n_students > 1: + _plr_stats = {} + for i in range(self.n_students): + _student_plr_stats = jax.tree_util.tree_map( + lambda x: x[i], plr_stats) # for agent0 + _plr_stats.update( + {f'{k}_a{i}': v for k, v in _student_plr_stats.items()}) + plr_stats = _plr_stats + else: + plr_stats = jax.tree_map(lambda x: x[0], plr_stats) + + stats.update({f'plr_{k}': v for k, v in plr_stats.items()}) + + if self.n_devices > 1: + stats = jax.tree_map(lambda x: jax.lax.pmean(x, 'device'), stats) + + return stats + + @partial(jax.jit, static_argnums=(0,)) + def run( + self, + rng, + train_state, + state, + start_state, + obs, + carry=None, + extra=None, + ep_stats=None, + plr_buffer=None): + # If device sharded, load sharded PLR buffer into train state + if self.n_devices > 1: + rng = jax.random.fold_in(rng, jax.lax.axis_index('device')) + train_state = train_state.replace(plr_buffer=plr_buffer) + + # Sample next training levels via PLR + rng, *vrngs = jax.random.split(rng, self.n_students+1) + obs, state, extra = self.benv.reset( + jnp.array(vrngs), self.n_parallel, 1) + + if self.use_parallel_eval: + n_level_samples = self.n_parallel//self._n_parallel_batches + new_levels = jax.tree_map( + lambda x: x.at[:, n_level_samples:2*n_level_samples].get(), state) + else: + n_level_samples = self.n_parallel + new_levels = state + + rng, subrng = jax.random.split(rng) + levels, level_idxs, is_replay, next_plr_buffer = \ + self.plr_mgr.sample(subrng, train_state.plr_buffer, + new_levels, n_level_samples) + train_state = train_state.replace(plr_buffer=next_plr_buffer) + + # If use_parallel_eval=True, need to combine replay and non-replay levels together + # Need to mutate levels as well + parent_idxs = jnp.full((self.n_students, self.n_parallel), -1) + if self.use_parallel_eval: # Parallel ACCEL + new_level_idxs = jnp.full_like(parent_idxs, -1) + + _all_levels = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=n_level_samples, src_len=n_level_samples), + )(state, levels) + _all_level_idxs = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=n_level_samples, src_len=n_level_samples), + )(new_level_idxs, level_idxs) + + if self.use_mutations: + rng, subrng = jax.random.split(rng) + mutated_levels, mutated_level_idxs, _parent_idxs = self._mutate_levels( + subrng, levels, level_idxs) + + fallback_levels = jax.tree_map( + lambda x: x.at[:, -n_level_samples:].get(), state) + fallback_level_idxs = jnp.full_like(mutated_level_idxs, -1) + + mutated_levels = jax.vmap( + lambda cond, x, y: jax.tree_map( + lambda _x, _y: jax.lax.select(cond, _x, _y), x, y + ))(is_replay, mutated_levels, fallback_levels) + + mutated_level_idxs = jax.vmap( + lambda cond, x, y: jax.tree_map( + lambda _x, _y: jax.lax.select(cond, _x, _y), x, y + ))(is_replay, mutated_level_idxs, fallback_level_idxs) + + _parent_idxs = jax.vmap( + lambda cond, x, y: jax.tree_map( + lambda _x, _y: jax.lax.select(cond, _x, _y), x, y + ))(is_replay, _parent_idxs, fallback_level_idxs) + + mutated_levels_start_idx = 2*n_level_samples + _all_levels = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=mutated_levels_start_idx, src_len=n_level_samples), + )(_all_levels, mutated_levels) + _all_level_idxs = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=mutated_levels_start_idx, src_len=n_level_samples), + )(_all_level_idxs, mutated_level_idxs) + parent_idxs = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=mutated_levels_start_idx, src_len=n_level_samples), + )(parent_idxs, _parent_idxs) + + levels = _all_levels + level_idxs = _all_level_idxs + + # dedupe levels, move into PLR buffer logic + if self.force_unique: + level_idxs, dupe_mask = self.plr_mgr.dedupe_levels( + next_plr_buffer, levels, level_idxs) + else: + dupe_mask = None + + # Evaluate levels + update PLR + result = self._eval_and_update_plr( + rng, levels, level_idxs, train_state, update_plr=jnp.array([True]*self.n_students), parent_idxs=parent_idxs, dupe_mask=dupe_mask) + rng, train_state, state, start_state, obs, carry, extra, ep_stats, \ + rollout_start_state, train_batch, ued_scores = result + + if self.use_parallel_eval: + replay_start_idx = self.n_eval*n_level_samples + replay_end_idx = 2*replay_start_idx + train_batch = jax.vmap( + lambda x: jax.tree_map( + lambda _x: _x.at[:, replay_start_idx:replay_end_idx].get(), x) + )(train_batch) + + # Gradient update + rng, subrng = jax.random.split(rng) + train_state, update_stats = self._efficient_grad_update( + subrng, train_state, train_batch, is_replay) + + # Mutation step + use_mutations = jnp.logical_and(self.use_mutations, is_replay) + # Already mutated above in parallel + use_mutations = jnp.logical_and( + use_mutations, not self.use_parallel_eval) + rng, arng, brng = jax.random.split(rng, 3) + + mutated_levels, mutated_level_idxs, parent_idxs = jax.lax.cond( + use_mutations.any(), + self._mutate_levels, + lambda *_: (levels, level_idxs, jnp.full_like(level_idxs, -1)), + *(arng, levels, level_idxs, ued_scores) + ) + + mutated_dupe_mask = jnp.zeros_like(mutated_level_idxs, dtype=jnp.bool_) + if self.force_unique: # Should move into update plr logic + mutated_level_idxs, mutated_dupe_mask = jax.lax.cond( + use_mutations.any(), + self.plr_mgr.dedupe_levels, + lambda *_: (mutated_level_idxs, mutated_dupe_mask), + *(next_plr_buffer, mutated_levels, mutated_level_idxs) + ) + + mutation_eval_result = jax.lax.cond( + use_mutations.any(), + self._eval_and_update_plr, + partial(self._eval_and_update_plr, fake=True), + *(brng, mutated_levels, mutated_level_idxs, train_state, use_mutations, parent_idxs, mutated_dupe_mask) + ) + train_state = mutation_eval_result[1] + + # Collect training env metrics + if self.track_env_metrics: + env_metrics = self.benv.get_env_metrics(levels) + else: + env_metrics = None + + plr_stats = self.plr_mgr.get_metrics(train_state.plr_buffer) + + stats = self._compile_stats( + update_stats, ep_stats, env_metrics, plr_stats) + + if self.n_devices > 1: + plr_buffer = train_state.plr_buffer + train_state = train_state.replace(plr_buffer=None) + + train_state = train_state.increment() + stats.update(dict(n_updates=train_state.n_updates[0])) + + return ( + stats, + rng, + train_state, + state, + start_state, + obs, + carry, + extra, + ep_stats, + plr_buffer + ) diff --git a/src/minimax/runners/xp_runner.py b/src/minimax/runners/xp_runner.py new file mode 100644 index 0000000..55c23d7 --- /dev/null +++ b/src/minimax/runners/xp_runner.py @@ -0,0 +1,310 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy +from functools import partial +from collections import defaultdict +import time + +import numpy as np +import jax +from jax.sharding import Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map + +from .eval_runner import EvalRunner +from .dr_runner import DRRunner +from .paired_runner import PAIREDRunner +from .plr_runner import PLRRunner +from minimax.util.rl import UEDScore, PopPLRManager +import minimax.envs as envs +import minimax.models as models +import minimax.agents as agents + + +class RunnerInfo: + def __init__( + self, + runner_cls, + is_ued=False): + self.runner_cls = runner_cls + self.is_ued = is_ued + + +RUNNER_INFO = { + 'dr': RunnerInfo( + runner_cls=DRRunner, + ), + 'plr': RunnerInfo( + runner_cls=PLRRunner, + ), + 'paired': RunnerInfo( + runner_cls=PAIREDRunner, + is_ued=True + ) +} + + +class ExperimentRunner: + def __init__( + self, + train_runner, + env_name, + agent_rl_algo, + student_model_name, + student_critic_model_name=None, + student_agent_kind="ppo", + teacher_model_name=None, + train_runner_kwargs={}, + env_kwargs={}, + ued_env_kwargs={}, + student_rl_kwargs={}, + teacher_rl_kwargs={}, + student_model_kwargs={}, + teacher_model_kwargs={}, + eval_kwargs={}, + eval_env_kwargs={}, + shaped_reward_steps=None, + n_devices=1, + xpid=None, + ): + self.env_name = env_name + self.agent_rl_algo = agent_rl_algo + self.is_ued = RUNNER_INFO[train_runner].is_ued + self.xpid = xpid + + dummy_env = envs.make( + env_name, + env_kwargs, + ued_env_kwargs)[0] + + # ---- Make agent ---- + student_model_kwargs['output_dim'] = dummy_env.action_space().n + student_model = models.make( + env_name=env_name, + model_name=student_model_name, + **student_model_kwargs + ) + + if student_agent_kind == "ppo": + student_agent = agents.PPOAgent( + model=student_model, + n_devices=n_devices, + **student_rl_kwargs + ) + else: + raise ValueError( + "Unknown student_agent_kind: {}".format(student_agent_kind)) + + # ---- Handle UED-related settings ---- + if self.is_ued: + max_teacher_steps = dummy_env.ued_max_episode_steps() + teacher_model_kwargs['n_scalar_embeddings'] = max_teacher_steps + teacher_model_kwargs['max_scalar'] = max_teacher_steps + teacher_model_kwargs['output_dim'] = dummy_env.ued_action_space().n + + teacher_model = models.make( + env_name=env_name, + model_name=teacher_model_name, + **teacher_model_kwargs + ) + + teacher_agent = agents.PPOAgent( + model=teacher_model, + n_devices=n_devices, + **teacher_rl_kwargs + ) + + train_runner_kwargs.update(dict( + teacher_agents=[teacher_agent] + )) + train_runner_kwargs.update(dict( + ued_env_kwargs=ued_env_kwargs + )) + + # Debug, tabulate student and teacher model + # import jax.numpy as jnp + # dummy_rng = jax.random.PRNGKey(0) + # obs, _ = dummy_env.reset(dummy_rng) + # hx = student_model.initialize_carry(dummy_rng, (1,)) + # ued_obs, _ = dummy_env.reset_teacher(dummy_rng) + # ued_hx = teacher_model.initialize_carry(dummy_rng, (1,)) + + # obs['image'] = jnp.expand_dims(obs['image'], 0) + # ued_obs['image'] = jnp.expand_dims(ued_obs['image'], 0) + + # print(student_model.tabulate(dummy_rng, obs, hx)) + # print(teacher_model.tabulate(dummy_rng, ued_obs, hx)) + + # import pdb; pdb.set_trace() + + # ---- Set up train runner ---- + runner_cls = RUNNER_INFO[train_runner].runner_cls + + # Set up learning rate annealing parameters + lr_init = train_runner_kwargs.lr + lr_final = train_runner_kwargs.lr_final + lr_anneal_steps = train_runner_kwargs.lr_anneal_steps + + if lr_final is None: + train_runner_kwargs.lr_final = lr_init + if train_runner_kwargs.lr_final == train_runner_kwargs.lr: + train_runner_kwargs.lr_anneal_steps = 0 + + self.runner = runner_cls( + env_name=env_name, + env_kwargs=env_kwargs, + student_agents=[student_agent], + n_devices=n_devices, + **train_runner_kwargs) + + # ---- Make eval runner ---- + if eval_kwargs.get('env_names') is None: + self.eval_runner = None + else: + self.eval_runner = EvalRunner( + pop=self.runner.student_pop, + env_kwargs=eval_env_kwargs, + **eval_kwargs) + + self._start_tick = 0 + + # ---- Set up device parallelism ---- + self.n_devices = n_devices + if n_devices > 1: + dummy_runner_state = self.reset_train_runner(jax.random.PRNGKey(0)) + self._shmap_run = self._make_shmap_run(dummy_runner_state) + else: + self._shmap_run = None + + @partial(jax.jit, static_argnums=(0,)) + def step(self, runner_state, evaluate=False): + if self.n_devices > 1: + run_fn = self._shmap_run + else: + run_fn = self.runner.run + + stats, *runner_state = run_fn(*runner_state) + + rng = runner_state[0] + rng, subrng = jax.random.split(rng) + + if self.eval_runner is not None: + params = runner_state[1].params + eval_stats = jax.lax.cond( + evaluate, + self.eval_runner.run, + self.eval_runner.fake_run, + *(subrng, params) + ) + else: + eval_stats = {} + + return stats, eval_stats, rng, *runner_state[1:] + + def _make_shmap_run(self, runner_state): + devices = mesh_utils.create_device_mesh((self.n_devices,)) + mesh = Mesh(devices, axis_names=('device')) + + in_specs, out_specs = self.runner.get_shmap_spec() + + return partial(shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False + )(self.runner.run) + + def train( + self, + rng, + agent_algo='ppo', + algo_runner='dr', + n_total_updates=1000, + logger=None, + log_interval=1, + test_interval=1, + checkpoint_interval=0, + archive_interval=0, + archive_init_checkpoint=False, + from_last_checkpoint=False + ): + """ + Entry-point for training + """ + # Load checkpoint if any + runner_state = self.runner.reset(rng) + + if from_last_checkpoint: + last_checkpoint_state = logger.load_last_checkpoint_state() + if last_checkpoint_state is not None: + runner_state = self.runner.load_checkpoint_state( + runner_state, + last_checkpoint_state + ) + self._start_tick = runner_state[1].n_iters[0] + + # Archive initialization weights if necessary + if archive_init_checkpoint: + logger.checkpoint( + self.runner.get_checkpoint_state(runner_state), + index=0, + archive_interval=1) + + # Train loop + log_on = logger is not None and log_interval > 0 + checkpoint_on = checkpoint_interval > 0 or archive_interval > 0 + train_state = runner_state[1] + + tick = self._start_tick + train_steps = tick*self.runner.step_batch_size * \ + self.runner.n_rollout_steps*self.n_devices + real_train_steps = train_steps//self.runner.n_students + + while (train_state.n_updates < n_total_updates).any(): + evaluate = test_interval > 0 and (tick+1) % test_interval == 0 + + start = time.time() + stats, eval_stats, *runner_state = \ + self.step(runner_state, evaluate) + end = time.time() + + if evaluate: + stats.update(eval_stats) + else: + stats.update({k: None for k in eval_stats.keys()}) + + train_state = runner_state[1] + + dsteps = self.runner.step_batch_size*self.runner.n_rollout_steps*self.n_devices + real_dsteps = dsteps//self.runner.n_students + train_steps += dsteps + real_train_steps += real_dsteps + sps = int(dsteps/(end-start)) + real_sps = int(real_dsteps/(end-start)) + stats.update(dict( + steps=train_steps, + sps=sps, + real_steps=real_train_steps, + real_sps=real_sps + )) + + tick += 1 + + if log_on and tick % log_interval == 0: + logger.log(stats, tick, ignore_val=-np.inf) + + if checkpoint_on and tick > 0: + if tick % checkpoint_interval == 0 \ + or (archive_interval > 0 and tick % archive_interval == 0): + checkpoint_state = \ + self.runner.get_checkpoint_state(runner_state) + logger.checkpoint( + checkpoint_state, + index=tick, + archive_interval=archive_interval) diff --git a/src/minimax/runners_ma/__init__.py b/src/minimax/runners_ma/__init__.py new file mode 100644 index 0000000..e5058e6 --- /dev/null +++ b/src/minimax/runners_ma/__init__.py @@ -0,0 +1,24 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .xp_runner import ExperimentRunner +from .eval_runner import EvalRunner +from .eval_runner_heterogenous import EvalRunnerHeterogenous +from .dr_runner import DRRunner +from .plr_runner import PLRRunner +from .paired_runner import PAIREDRunner + + +__all__ = [ + ExperimentRunner, + EvalRunner, + EvalRunnerHeterogenous, + DRRunner, + PLRRunner, + PAIREDRunner +] diff --git a/src/minimax/runners_ma/dr_runner.py b/src/minimax/runners_ma/dr_runner.py new file mode 100644 index 0000000..f3784f3 --- /dev/null +++ b/src/minimax/runners_ma/dr_runner.py @@ -0,0 +1,569 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from typing import Dict, Tuple, Optional +import inspect + +import chex +import einops +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +import optax +import flax +import flax.linen as nn +from flax.core.frozen_dict import FrozenDict +from torch import NoneType + +import minimax.envs as envs +from minimax.util import pytree as _tree_util +from minimax.util.rl import ( + AgentPop, + VmapMAPPOTrainState, + RolloutStorageSeperate, + RollingStats +) + + +class DRRunner: + """ + Orchestrates rollouts across one or more students. + The main components at play: + - AgentPop: Manages train state and batched inference logic + for a population of agents. + - BatchEnv: Manages environment step and reset logic, using a + populaton of agents. + - RolloutStorage: Manages the storing and sampling of collected txns. + - PPO: Handles PPO updates, which take a train state + batch of txns. + """ + + def __init__( + self, + env_name, + env_kwargs, + student_agents, + student_agent_kind, + n_students=1, + n_parallel=1, + n_eval=1, + n_rollout_steps=256, + lr=1e-4, + lr_final=None, + lr_anneal_steps=0, + max_grad_norm=0.5, + discount=0.99, + gae_lambda=0.95, + adam_eps=1e-5, + normalize_return=False, + track_env_metrics=False, + n_unroll_rollout=1, + n_devices=1, + render=False, + shaped_reward=False, + ): + + assert len(student_agents) == 1, 'Only one type of student supported.' + assert n_parallel % n_devices == 0, 'Num envs must be divisible by num devices.' + + self.n_students = n_students + self.n_parallel = n_parallel // n_devices + self.n_eval = n_eval + self.n_devices = n_devices + self.step_batch_size = n_students*n_eval*n_parallel + self.n_rollout_steps = n_rollout_steps + self.n_updates = 0 + self.lr = lr + self.lr_final = lr if lr_final is None else lr_final + self.lr_anneal_steps = lr_anneal_steps + self.max_grad_norm = max_grad_norm + self.adam_eps = adam_eps + self.normalize_return = normalize_return + self.track_env_metrics = track_env_metrics + self.n_unroll_rollout = n_unroll_rollout + self.render = render + + self.shaped_reward = shaped_reward + + self.student_agent_kind = student_agent_kind + self.student_pop = AgentPop(student_agents[0], n_agents=n_students) + + self.env, self.env_params = envs.make( + env_name, + env_kwargs=env_kwargs + ) + self._action_shape = self.env.action_space().shape + + wrappers_lst = ['monitor_return', 'monitor_ep_metrics'] + if self.student_agent_kind == "mappo": + wrappers_lst = ['world_state_wrapper'] + wrappers_lst + + self.benv = envs.BatchEnv( + env_name=env_name, + n_parallel=self.n_parallel, + n_eval=self.n_eval, + env_kwargs=env_kwargs, + wrappers=wrappers_lst, + ) + self.action_dtype = self.benv.env.action_space().dtype + + self.student_rollout = RolloutStorageSeperate( + discount=discount, + gae_lambda=gae_lambda, + n_steps=n_rollout_steps, + n_agents=n_students, + n_envs=self.n_parallel, + n_eval=self.n_eval, + action_space=self.env.action_space(), + obs_space=self.env.observation_space(), + obs_space_shared_shape=self.benv.env.world_state_size(), + agent=self.student_pop.agent, + ) + + monitored_metrics = self.benv.env.get_monitored_metrics() + self.rolling_stats = RollingStats( + names=monitored_metrics, + window=10, + ) + self._update_ep_stats = jax.vmap( + jax.vmap(self.rolling_stats.update_stats)) + + if self.render: + from minimax.envs.viz.grid_viz import GridVisualizer + self.viz = GridVisualizer() + self.viz.show() + + def reset(self, rng): + self.n_updates = 0 + + n_parallel = self.n_parallel*self.n_devices + + rngs, *vrngs = jax.random.split(rng, self.n_students+1) + obs, state, extra = self.benv.reset( + jnp.array(vrngs), n_parallel=n_parallel) + + # dummy_obs = jax.tree_util.tree_map(lambda x: x[0], obs) # for one agent only + dummy_obs = self._concat_multi_agent_obs(obs) + dummy_shared_obs = self._concat_multi_agent_obs(obs['world_state']) + + rng, subrng = jax.random.split(rng) + if self.student_pop.agent.is_recurrent: + actor_carry, critic_carry = self.student_pop.init_carry( + subrng, dummy_obs) + self.zero_carry = jax.tree_map( + lambda x: x.at[:, :self.n_parallel].get(), actor_carry) + else: + actor_carry, critic_carry = None, None + + rng, subrng = jax.random.split(rng) + actor_params, critic_params = self.student_pop.init_params( + subrng, (dummy_obs[0], dummy_shared_obs[0])) + + schedule_fn = optax.linear_schedule( + init_value=-float(self.lr), + end_value=-float(self.lr_final), + transition_steps=self.lr_anneal_steps, + ) + + tx_actor = optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + optax.adam(learning_rate=float(self.lr), eps=self.adam_eps) + ) + + tx_critic = optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + optax.adam(learning_rate=float(self.lr), eps=self.adam_eps) + ) + + shaped_reward_coeff_value = 1.0 if self.shaped_reward else 0.0 + shaped_reward_coeff = jnp.full( + (self.n_students, 1), fill_value=shaped_reward_coeff_value) + train_state = VmapMAPPOTrainState.create( + actor_apply_fn=self.student_pop.agent.evaluate_action, + actor_params=actor_params, + actor_tx=tx_actor, + critic_apply_fn=self.student_pop.agent.get_value, + critic_params=critic_params, + critic_tx=tx_critic, + shaped_reward_coeff=shaped_reward_coeff, + ) + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.n_students, n_parallel*self.n_eval)) + + start_state = state + + return ( + rng, + train_state, + state, + start_state, # Used to track metrics from starting state + obs, + actor_carry, + critic_carry, + extra, + ep_stats + ) + + def get_checkpoint_state(self, state): + _state = list(state) + _state[1] = state[1].state_dict + + return _state + + def load_checkpoint_state(self, runner_state, state): + runner_state = list(runner_state) + runner_state[1] = runner_state[1].load_state_dict(state[1]) + + return tuple(runner_state) + + @partial(jax.jit, static_argnums=(0, 2)) + def _get_transition( + self, + rng, + pop, + actor_params, + critic_params, + rollout, + state, + start_state, + obs, + actor_carry, + critic_carry, + done, + extra=None): + # Sample action + + ma_obs = self._concat_multi_agent_obs(obs) + + # PRINT THE CURRENT STATE + + _, pi_params, next_actor_carry = jax.vmap(pop.act, in_axes=(None, 2, 2, None))( + actor_params, ma_obs, actor_carry, done) + next_actor_carry = jax.tree_map(lambda x: einops.rearrange( + x, 't n a d -> a t n d'), next_actor_carry) + shared_obs = self._concat_multi_agent_obs(obs['world_state']) + value, next_critic_carry = jax.vmap(pop.get_value, in_axes=(None, 2, 2, None))( + critic_params, shared_obs, critic_carry, done) + next_critic_carry = jax.tree_map(lambda x: einops.rearrange( + x, 't n a d -> a t n d'), next_critic_carry) + + pi = pop.get_action_dist(pi_params, dtype=self.action_dtype) + rng, subrng = jax.random.split(rng) + action = pi.sample(seed=subrng) + log_pi = pi.log_prob(action) + + env_action = { + 'agent_0': action[0], + 'agent_1': action[1] + } + + rng, *vrngs = jax.random.split(rng, self.n_students+1) + (next_obs, + next_state, + reward, + done, + info, + extra) = self.benv.step(jnp.array(vrngs), state, env_action, extra) + + # jax.debug.print("Current state (r: {r}, sparse: {spa}, shaped: {sha}) =\n{a}", spa=info["sparse_reward"][0, 0].mean(), sha=info["shaped_reward"][0, 0].mean(), r=reward[0, 0], a=ma_obs[0, 0, 0, :, :, 0] + # * 1 + ma_obs[0, 0, 0, :, :, 1]*2+ma_obs[0, 0, 0, :, :, 11]*3) + + next_start_state = jax.vmap(_tree_util.pytree_select)( + done, next_state, start_state + ) + + # Add transition to storage + log_pi = einops.rearrange(log_pi, 'a s n -> s n a') + value = einops.rearrange(value, 'a s n -> s n a') + + action = einops.rearrange(action, 'a s n -> s n a') + + done_ = jnp.concatenate( + [done[:, :, jnp.newaxis], done[:, :, jnp.newaxis]], axis=2) + + # jax.debug.print("sparse reward = {b}, reward = {c}", # a=info["shaped_reward"].mean(), + # b=info["sparse_reward"].mean(), c=reward.mean()) + step = (ma_obs, shared_obs, action, info["sparse_reward"], + info["shaped_reward"], done_, log_pi, value) + if actor_carry is not None and critic_carry is not None: + step += (actor_carry, critic_carry) + + rollout = self.student_rollout.append(rollout, *step) + + if self.render: + self.viz.render( + self.benv.env.params, + jax.tree_util.tree_map(lambda x: x[0][0], state)) + + return ( + rollout, + next_state, + next_start_state, + next_obs, + jax.tree_map(lambda x: einops.rearrange( + x, 'n a s d -> s n a d'), next_actor_carry), + jax.tree_map(lambda x: einops.rearrange( + x, 'n a s d -> s n a d'), next_critic_carry), + done, + info, + extra + ) + + @partial(jax.jit, static_argnums=(0,)) + def _rollout_students( + self, + rng, + train_state, + state, + start_state, + obs, + actor_carry, + critic_carry, + done, + extra=None, + ep_stats=None): + rollout = self.student_rollout.reset() + + rngs = jax.random.split(rng, self.n_rollout_steps) + + def _scan_rollout(scan_carry, rng): + rollout, state, start_state, obs, actor_carry, critic_carry, done, extra, ep_stats, train_state = scan_carry + + next_scan_carry = \ + self._get_transition( + rng, + self.student_pop, + jax.lax.stop_gradient(train_state.actor_params), + jax.lax.stop_gradient(train_state.critic_params), + rollout, + state, + start_state, + obs, + actor_carry, + critic_carry, + done, + extra) + (rollout, + next_state, + next_start_state, + next_obs, + next_actor_carry, + next_critic_carry, + done, + info, + extra) = next_scan_carry + + ep_stats = self._update_ep_stats(ep_stats, done, info) + + return ( + rollout, + next_state, + next_start_state, + next_obs, + next_actor_carry, + next_critic_carry, + done, + extra, + ep_stats, + train_state), None + + (rollout, + state, + start_state, + obs, + actor_carry, + critic_carry, + done, + extra, + ep_stats, + train_state), _ = jax.lax.scan( + _scan_rollout, + (rollout, + state, + start_state, + obs, + actor_carry, + critic_carry, + done, + extra, + ep_stats, + train_state), + rngs, + length=self.n_rollout_steps, + ) + + return rollout, state, start_state, obs, actor_carry, critic_carry, extra, ep_stats, train_state + + @partial(jax.jit, static_argnums=(0,)) + def _compile_stats(self, update_stats, ep_stats, env_metrics=None, shaped_reward_coeff=None): + + info = {k: ep_stats[k] for k in self.rolling_stats.names} + + stats = jax.vmap(lambda info: jax.tree_map(lambda x: x.mean(), info))( + info + ) + + if shaped_reward_coeff is not None: + update_stats.update( + {"shaped_reward_coeff": shaped_reward_coeff}) + + stats.update(update_stats) + + if self.n_students > 1: + _stats = {} + for i in range(self.n_students): + _student_stats = jax.tree_util.tree_map( + lambda x: x[i], stats) # for agent0 + _stats.update( + {f'a{i}/{k}': v for k, v in _student_stats.items()}) + stats = _stats + + if self.track_env_metrics: + mean_env_metrics = jax.vmap(lambda info: jax.tree_map( + lambda x: x.mean(), info))(env_metrics) + mean_env_metrics = {f'env/{k}': v for k, + v in mean_env_metrics.items()} + + if self.n_students > 1: + _env_metrics = {} + for i in range(self.n_students): + _student_env_metrics = jax.tree_util.tree_map( + lambda x: x[i], mean_env_metrics) # for agent0 + _env_metrics.update( + {f'{k}_a{i}': v for k, v in _student_env_metrics.items()}) + mean_env_metrics = _env_metrics + + stats.update(mean_env_metrics) + + if self.n_students == 1: + stats = jax.tree_map(lambda x: x[0], stats) + + if self.n_devices > 1: + stats = jax.tree_map(lambda x: jax.lax.pmean(x, 'device'), stats) + + return stats + + def get_shmap_spec(self): + runner_state_size = len(inspect.signature(self.run).parameters) + in_spec = [P(None, 'device'),]*(runner_state_size) + out_spec = [P(None, 'device'),]*(runner_state_size) + + in_spec[:2] = [P(None),]*2 + in_spec = tuple(in_spec) + out_spec = (P(None),) + in_spec + + return in_spec, out_spec + + @partial(jax.jit, static_argnums=(0,)) + def run( + self, + rng, + train_state, + state, + start_state, + obs, + actor_carry=None, + critic_carry=None, + extra=None, + ep_stats=None): + """ + Perform one update step: rollout all students and teachers + update with PPO + """ + if self.n_devices > 1: + rng = jax.random.fold_in(rng, jax.lax.axis_index('device')) + + rng, *vrngs = jax.random.split(rng, self.n_students+1) + rollout_batch_shape = (self.n_students, self.n_parallel*self.n_eval) + + obs, state, extra = self.benv.reset(jnp.array(vrngs)) + ep_stats = self.rolling_stats.reset_stats( + batch_shape=rollout_batch_shape) + + rollout_start_state = state + + done = jnp.zeros(rollout_batch_shape, dtype=jnp.bool_) + rng, subrng = jax.random.split(rng) + rollout, state, start_state, obs, actor_carry, critic_carry, extra, ep_stats, train_state = \ + self._rollout_students( + subrng, + train_state, + state, + start_state, + obs, + actor_carry, + critic_carry, + done, + extra, + ep_stats + ) + + reward = rollout["rewards"].sum(axis=1).mean(-1)[:, :, jnp.newaxis] + shaped_reward = rollout["shaped_rewards"].sum( + axis=1).mean(-1)[:, :, jnp.newaxis] + + ep_stats["reward"] = reward + ep_stats["shaped_reward"] = shaped_reward + ep_stats["shaped_reward_scaled_by_shaped_reward_coeff"] = shaped_reward * \ + train_state.shaped_reward_coeff + ep_stats["reward_p_shaped_reward_scaled"] = reward + shaped_reward * \ + train_state.shaped_reward_coeff + + shared_obs = self._concat_multi_agent_obs(obs['world_state']) + value, _ = jax.vmap(self.student_pop.get_value, in_axes=(None, 2, 2))( + jax.lax.stop_gradient(train_state.critic_params), + shared_obs, + critic_carry + ) + + value = einops.rearrange( + value, "n_env_agents n_students n_parallel -> n_students n_parallel n_env_agents") + train_batch = self.student_rollout.get_batch( + rollout, + value, + train_state.shaped_reward_coeff + ) + + # PPOAgent vmaps over the train state and batch. Batch must be N x EM + rng, subrng = jax.random.split(rng) + train_state, update_stats = self.student_pop.update( + subrng, train_state, train_batch) + + # Collect env metrics + if self.track_env_metrics: + env_metrics = self.benv.get_env_metrics(rollout_start_state) + else: + env_metrics = None + + stats = self._compile_stats( + update_stats, ep_stats, env_metrics, shaped_reward_coeff=train_state.shaped_reward_coeff) + stats.update(dict(n_updates=train_state.n_updates[0])) + + train_state = train_state.increment() + self.n_updates += 1 + + return ( + stats, + rng, + train_state, + state, + start_state, + obs, + actor_carry, + critic_carry, + extra, + ep_stats, + rollout_start_state + ) + + def _concat_multi_agent_obs(self, obs: Dict[str, chex.Array]) -> chex.Array: + """Concatenates a obs dictionary that was built for two env agents. + Doubles the number of parallel_envs, i.e. (num_students, n_parallel, ...) -> (num_students, 2*n_parallel, ...) + """ + return jnp.concatenate([obs['agent_0'][:, :, jnp.newaxis, :], obs['agent_1'][:, :, jnp.newaxis, :]], axis=2) diff --git a/src/minimax/runners_ma/eval_runner.py b/src/minimax/runners_ma/eval_runner.py new file mode 100644 index 0000000..9f7a508 --- /dev/null +++ b/src/minimax/runners_ma/eval_runner.py @@ -0,0 +1,371 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from typing import Dict, Tuple, Optional + +import chex +import einops +import numpy as np +import jax +import jax.numpy as jnp + +import minimax.envs as envs +from minimax.util.rl import ( + RollingStats +) +import minimax.util.pytree as _tree_util + + +def generate_all_kwargs_combos(arg_choices): + def update_kwargs_with_choices(prev_combos, k, choices): + updated_combos = [] + for v in choices: + for p in prev_combos: + updated = p.copy() + updated[k] = v + updated_combos.append(updated) + + return updated_combos + + all_combos = [{}] + for k, choices in arg_choices.items(): + all_combos = update_kwargs_with_choices(all_combos, k, choices) + + return all_combos + + +def create_envs_for_kwargs(env_names, kwargs): + # Check for csv kwargs + arg_choices = {} + varied_args = [] + for k, v in kwargs.items(): + if isinstance(v, str) and ',' in v: + vs = eval(v) + arg_choices[k] = vs + varied_args.append(k) + elif isinstance(v, str): + arg_choices[k] = [eval(v)] + else: + arg_choices[k] = [v] + + # List of kwargs + kwargs_combos = generate_all_kwargs_combos(arg_choices) + + env_infos = [] + incl_ext = len(varied_args) > 0 + for name in env_names: + for kwargs in kwargs_combos: + if incl_ext and len(kwargs) > 0: + ext = ':'.join([f'{k}={kwargs[k]}' for k in varied_args]) + ext_name = f'{name}:{ext}' + else: + ext_name = name + env_infos.append( + (name, ext_name, kwargs) + ) + + return env_infos + + +class EvalRunner: + def __init__( + self, + pop, + env_names, + env_kwargs=None, + n_episodes=10, + agent_idxs='*', + render_mode=None): + + self.pop = pop + + if isinstance(agent_idxs, str): + if "*" in agent_idxs: + self.agent_idxs = np.arange(pop.n_agents) + else: + self.agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + else: + self.agent_idxs = agent_idxs # assume array + + assert np.max(self.agent_idxs) < pop.n_agents, \ + 'Agent index is out of bounds.' + + if isinstance(env_names, str): + env_names = [ + x.strip() for x in env_names.split(',') + ] + + self.n_episodes = n_episodes + env_infos = create_envs_for_kwargs(env_names, env_kwargs) + env_names = [] + self.ext_env_names = [] + env_kwargs = [] + for (name, ext_name, kwargs) in env_infos: + env_names.append(name) + self.ext_env_names.append(ext_name) + env_kwargs.append(kwargs) + self.n_envs = len(env_names) + + self.benvs = [] + self.env_params = [] + self.env_has_solved_rate = [] + for env_name, kwargs in zip(env_names, env_kwargs): + benv = envs.BatchEnv( + env_name=env_name, + n_parallel=n_episodes, + n_eval=1, + env_kwargs=kwargs, + wrappers=['monitor_return', 'monitor_ep_metrics'] + ) + self.benvs.append(benv) + self.env_params.append(benv.env.params) + self.env_has_solved_rate.append( + benv.env.eval_solved_rate is not None) + + self.action_dtype = self.benvs[0].env.action_space().dtype + + monitored_metrics = self.benvs[0].env.get_monitored_metrics() + self.rolling_stats = RollingStats(names=monitored_metrics, window=1) + self._update_ep_stats = jax.vmap( + jax.vmap( + self.rolling_stats.update_stats, in_axes=(0, 0, 0, None)), + in_axes=(0, 0, 0, None)) + + self.test_return_pre = 'test_return' + self.test_solved_rate_pre = 'test_solved_rate' + + self.render_mode = render_mode + if render_mode: + from minimax.envs.viz.grid_viz import GridVisualizer + self.viz = GridVisualizer() + self.viz.show() + + if render_mode == 'ipython': + from IPython import display + self.ipython_display = display + + def load_checkpoint_state(self, runner_state, state): + runner_state = list(runner_state) + runner_state[1] = runner_state[1].load_state_dict(state[1]) + + return tuple(runner_state) + + def _concat_multi_agent_obs(self, obs: Dict[str, chex.Array]) -> chex.Array: + """Concatenates a obs dictionary that was built for two env agents. + Doubles the number of parallel_envs, i.e. (num_students, n_parallel, ...) -> (num_students, 2*n_parallel, ...) + """ + return jnp.concatenate([obs['agent_0'][:, :, jnp.newaxis, :], obs['agent_1'][:, :, jnp.newaxis, :]], axis=2) + + @partial(jax.jit, static_argnums=(0, 2)) + def _get_transition( + self, + rng, + benv, + actor_params, + state, + obs, + actor_carry, + zero_carry, + done, + extra): + # Sample action + ma_obs = self._concat_multi_agent_obs(obs) + _, pi_params, next_actor_carry = jax.vmap(self.pop.act, in_axes=(None, 2, 2, None))( + actor_params, ma_obs, actor_carry, done) + next_actor_carry = jax.tree_map(lambda x: einops.rearrange( + x, 't n a d -> a t n d'), next_actor_carry) + + pi = self.pop.get_action_dist(pi_params, dtype=self.action_dtype) + rng, subrng = jax.random.split(rng) + action = pi.sample(seed=subrng) + log_pi = pi.log_prob(action) + + env_action = { + 'agent_0': action[0], + 'agent_1': action[1] + } + + rng, *vrngs = jax.random.split(rng, self.pop.n_agents+1) + (next_obs, + next_state, + reward, + done, + info, + extra) = benv.step(jnp.array(vrngs), state, env_action, extra) + + log_pi = einops.rearrange(log_pi, 'a s n -> s n a') + + action = einops.rearrange(action, 'a s n -> s n a') + + done_ = jnp.concatenate( + [done[:, :, jnp.newaxis], done[:, :, jnp.newaxis]], axis=2) + + next_actor_carry = jax.tree_map(lambda x: einops.rearrange( + x, 'n a s d -> s n a d'), next_actor_carry) + step = (ma_obs, action, info["sparse_reward"], + info["shaped_reward"], done_, log_pi) + if actor_carry is not None: + step += (actor_carry,) + + if actor_carry is not None: + next_actor_carry = jax.vmap(_tree_util.pytree_select)( + done, zero_carry, next_actor_carry) + + if self.render_mode: + self.viz.render( + benv.env.params, + jax.tree_util.tree_map(lambda x: x[0][0], state), + highlight=False) + if self.render_mode == 'ipython': + self.ipython_display.display(self.viz.window.fig) + self.ipython_display.clear_output(wait=True) + return ( + next_state, + next_obs, + next_actor_carry, + done, + info, + extra + ) + + @partial(jax.jit, static_argnums=(0, 2)) + def _rollout_benv( + self, + rng, + benv, + params, + env_params, + state, + obs, + carry, + zero_carry, + extra, + done, + ep_stats): + + def _scan_rollout(scan_carry, rng): + (state, + obs, + carry, + extra, + done, + ep_stats) = scan_carry + + step = \ + self._get_transition( + rng, + benv, + params, + state, + obs, + carry, + zero_carry, + done, + extra) + + (next_state, + next_obs, + next_carry, + done, + info, + extra) = step + + ep_stats = self._update_ep_stats(ep_stats, done, info, 1) + + return (next_state, next_obs, next_carry, extra, done, ep_stats), None + + n_steps = benv.env.max_episode_steps() + rngs = jax.random.split(rng, n_steps) + (state, + obs, + carry, + extra, + done, + ep_stats), _ = jax.lax.scan( + _scan_rollout, + (state, obs, carry, extra, done, ep_stats), + rngs, + length=n_steps) + + return ep_stats + + @partial(jax.jit, static_argnums=(0,)) + def run(self, rng, params): + """ + Rollout agents on each env. + + For each env, run n_eval episodes in parallel, + where each is indexed to return in order. + """ + eval_stats = self.fake_run(rng, params) + rng, *rollout_rngs = jax.random.split(rng, self.n_envs+1) + for i, (benv, env_param) in enumerate(zip(self.benvs, self.env_params)): + rng, *reset_rngs = jax.random.split(rng, self.pop.n_agents+1) + obs, state, extra = benv.reset(jnp.array(reset_rngs)) + + if self.pop.agent.is_recurrent: + rng, subrng = jax.random.split(rng) + dummy_obs = self._concat_multi_agent_obs(obs) + actor_zero_carry, _ = self.pop.init_carry(subrng, dummy_obs) + else: + actor_zero_carry = None + + # Reset episodic stats + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.pop.n_agents, self.n_episodes)) + + done = jnp.zeros( + (self.pop.n_agents, self.n_episodes), dtype=jnp.bool_) + + ep_stats = self._rollout_benv( + rollout_rngs[i], + benv, + jax.lax.stop_gradient(params), + env_param, + state, + obs, + actor_zero_carry, + actor_zero_carry, + extra, + done, + ep_stats) + + env_name = self.ext_env_names[i] + mean_return = ep_stats['return'].mean(1) + std_return = ep_stats['return'].std(1) + + if self.env_has_solved_rate[i]: + mean_solved_rate = jax.vmap( + jax.vmap(benv.env.eval_solved_rate))(ep_stats).mean(1) + + for idx in self.agent_idxs: + eval_stats[f'eval/a{idx}:{self.test_return_pre}:{env_name}'] = mean_return[idx].squeeze() + eval_stats[f'eval/a{idx}:{self.test_return_pre}_std:{env_name}'] = std_return[idx].squeeze() + if self.env_has_solved_rate[i]: + eval_stats[f'eval/a{idx}:{self.test_solved_rate_pre}:{env_name}'] = mean_solved_rate[idx].squeeze() + + return eval_stats + + def fake_run(self, rng, params): + eval_stats = {} + for i, env_name in enumerate(self.ext_env_names): + for idx in self.agent_idxs: + eval_stats.update({ + f'eval/a{idx}:{self.test_return_pre}:{env_name}': 0. + }) + eval_stats.update({ + f'eval/a{idx}:{self.test_return_pre}_std:{env_name}': 0. + }) + if self.env_has_solved_rate[i]: + eval_stats.update({ + f'eval/a{idx}:{self.test_solved_rate_pre}:{env_name}': 0., + }) + + return eval_stats diff --git a/src/minimax/runners_ma/eval_runner_heterogenous.py b/src/minimax/runners_ma/eval_runner_heterogenous.py new file mode 100644 index 0000000..08b735d --- /dev/null +++ b/src/minimax/runners_ma/eval_runner_heterogenous.py @@ -0,0 +1,388 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from typing import Dict + +import chex +import numpy as np +import jax +import jax.numpy as jnp + +import minimax.envs as envs +from minimax.util.rl import ( + AgentPop, + RollingStats +) +import minimax.util.pytree as _tree_util + + +def generate_all_kwargs_combos(arg_choices): + def update_kwargs_with_choices(prev_combos, k, choices): + updated_combos = [] + for v in choices: + for p in prev_combos: + updated = p.copy() + updated[k] = v + updated_combos.append(updated) + + return updated_combos + + all_combos = [{}] + for k, choices in arg_choices.items(): + all_combos = update_kwargs_with_choices(all_combos, k, choices) + + return all_combos + + +def create_envs_for_kwargs(env_names, kwargs): + # Check for csv kwargs + arg_choices = {} + varied_args = [] + for k, v in kwargs.items(): + if isinstance(v, str) and ',' in v: + vs = eval(v) + arg_choices[k] = vs + varied_args.append(k) + elif isinstance(v, str): + arg_choices[k] = [eval(v)] + else: + arg_choices[k] = [v] + + # List of kwargs + kwargs_combos = generate_all_kwargs_combos(arg_choices) + + env_infos = [] + incl_ext = len(varied_args) > 0 + for name in env_names: + for kwargs in kwargs_combos: + if incl_ext and len(kwargs) > 0: + ext = ':'.join([f'{k}={kwargs[k]}' for k in varied_args]) + ext_name = f'{name}:{ext}' + else: + ext_name = name + env_infos.append( + (name, ext_name, kwargs) + ) + + return env_infos + + +class EvalRunnerHeterogenous: + def __init__( + self, + pop, + env_names, + env_kwargs=None, + n_episodes=10, + agent_idxs='*', + render_mode=None): + + self.pop = pop + + if isinstance(agent_idxs, str): + if "*" in agent_idxs: + self.agent_idxs = np.arange(pop.n_agents) + else: + self.agent_idxs = \ + np.array([int(x) for x in agent_idxs.split(',')]) + else: + self.agent_idxs = agent_idxs # assume array + + assert np.max(self.agent_idxs) < pop.n_agents, \ + 'Agent index is out of bounds.' + + if isinstance(env_names, str): + env_names = [ + x.strip() for x in env_names.split(',') + ] + + self.n_episodes = n_episodes + env_infos = create_envs_for_kwargs(env_names, env_kwargs) + env_names = [] + self.ext_env_names = [] + env_kwargs = [] + for (name, ext_name, kwargs) in env_infos: + env_names.append(name) + self.ext_env_names.append(ext_name) + env_kwargs.append(kwargs) + self.n_envs = len(env_names) + + self.benvs = [] + self.env_params = [] + self.env_has_solved_rate = [] + for env_name, kwargs in zip(env_names, env_kwargs): + benv = envs.BatchEnv( + env_name=env_name, + n_parallel=n_episodes, + n_eval=1, + env_kwargs=kwargs, + wrappers=['monitor_return', 'monitor_ep_metrics'] + ) + self.benvs.append(benv) + self.env_params.append(benv.env.params) + self.env_has_solved_rate.append( + benv.env.eval_solved_rate is not None) + + self.action_dtype = self.benvs[0].env.action_space().dtype + + monitored_metrics = self.benvs[0].env.get_monitored_metrics() + self.rolling_stats = RollingStats(names=monitored_metrics, window=1) + self._update_ep_stats = jax.vmap( + jax.vmap( + self.rolling_stats.update_stats, in_axes=(0, 0, 0, None)), + in_axes=(0, 0, 0, None)) + + self.test_return_pre = 'test_return' + self.test_solved_rate_pre = 'test_solved_rate' + + self.render_mode = render_mode + if render_mode: + from minimax.envs.viz.grid_viz import GridVisualizer + self.viz = GridVisualizer() + self.viz.show() + + if render_mode == 'ipython': + from IPython import display + self.ipython_display = display + + def load_checkpoint_state(self, runner_state, state): + runner_state = list(runner_state) + runner_state[1] = runner_state[1].load_state_dict(state[1]) + + return tuple(runner_state) + + def _concat_multi_agent_obs(self, obs: Dict[str, chex.Array]) -> chex.Array: + """Concatenates a obs dictionary that was built for two env agents. + Doubles the number of parallel_envs, i.e. (num_students, n_parallel, ...) -> (num_students, 2*n_parallel, ...) + """ + return jnp.concatenate([obs['agent_0'][:, :, jnp.newaxis, :], obs['agent_1'][:, :, jnp.newaxis, :]], axis=2) + + @partial(jax.jit, static_argnums=(0, 2)) + def _get_transition( + self, + rng, + benv, + actor_0_params, + actor_1_params, + state, + obs, + actor_0_carry, + actor_1_carry, + zero_0_carry, + zero_1_carry, + done, + extra): + _, _, pi_0_params, pi_1_params, next_actor_0_carry, next_actor_1_carry = self.pop.act( + (actor_0_params, actor_1_params), obs, (actor_0_carry, actor_1_carry), done) + + pi_0 = self.pop.get_action_0_dist(pi_0_params, dtype=self.action_dtype) + pi_1 = self.pop.get_action_1_dist(pi_1_params, dtype=self.action_dtype) + rng, subrng = jax.random.split(rng) + action_0 = pi_0.sample(seed=subrng) + log_pi_0 = pi_0.log_prob(action_0) + + rng, subrng = jax.random.split(rng) + action_1 = pi_1.sample(seed=subrng) + log_pi_1 = pi_1.log_prob(action_1) + + env_action = { + 'agent_0': action_0, + 'agent_1': action_1 + } + + rng, *vrngs = jax.random.split(rng, self.pop.n_agents+1) + (next_obs, + next_state, + reward, + done, + info, + extra) = benv.step(jnp.array(vrngs), state, env_action, extra) + + done_ = jnp.concatenate( + [done[:, :, jnp.newaxis], done[:, :, jnp.newaxis]], axis=2) + + if actor_0_carry is not None: + next_actor_0_carry = jax.vmap(_tree_util.pytree_select)( + done, zero_0_carry, next_actor_0_carry) + + if actor_1_carry is not None: + next_actor_1_carry = jax.vmap(_tree_util.pytree_select)( + done, zero_1_carry, next_actor_1_carry) + + if self.render_mode: + self.viz.render( + benv.env.params, + jax.tree_util.tree_map(lambda x: x[0][0], state)) + if self.render_mode == 'ipython': + self.ipython_display.display(self.viz.window.fig) + self.ipython_display.clear_output(wait=True) + return ( + next_state, + next_obs, + next_actor_0_carry, + next_actor_1_carry, + done, + info, + extra + ) + + @partial(jax.jit, static_argnums=(0, 2)) + def _rollout_benv( + self, + rng, + benv, + params_0, + params_1, + env_params, + state, + obs, + carry_0, + carry_1, + zero_0_carry, + zero_1_carry, + extra, + done, + ep_stats): + + def _scan_rollout(scan_carry, rng): + (state, + obs, + carry_0, + carry_1, + extra, + done, + ep_stats) = scan_carry + + step = \ + self._get_transition( + rng, + benv, + params_0, + params_1, + state, + obs, + carry_0, + carry_1, + zero_0_carry, + zero_1_carry, + done, + extra) + + (next_state, + next_obs, + next_0_carry, + next_1_carry, + done, + info, + extra) = step + + ep_stats = self._update_ep_stats(ep_stats, done, info, 1) + + return (next_state, next_obs, next_0_carry, next_1_carry, extra, done, ep_stats), None + + n_steps = benv.env.max_episode_steps() + rngs = jax.random.split(rng, n_steps) + (state, + obs, + carry_0, + carry_1, + extra, + done, + ep_stats), _ = jax.lax.scan( + _scan_rollout, + (state, obs, carry_0, carry_1, extra, done, ep_stats), + rngs, + length=n_steps) + + return ep_stats + + @partial(jax.jit, static_argnums=(0,)) + def run(self, rng, params_0, params_1): + """ + Rollout agents on each env. + + For each env, run n_eval episodes in parallel, + where each is indexed to return in order. + """ + eval_stats = self.fake_run( + rng, params_0) # Params do not matter for the fake run + rng, *rollout_rngs = jax.random.split(rng, self.n_envs+1) + for i, (benv, env_param) in enumerate(zip(self.benvs, self.env_params)): + rng, *reset_rngs = jax.random.split(rng, self.pop.n_agents+1) + obs, state, extra = benv.reset(jnp.array(reset_rngs)) + + if self.pop.agent_0.is_recurrent: + rng, subrng = jax.random.split(rng) + actor_0_zero_carry, _ = self.pop.init_carry_agent_0( + subrng, obs['agent_0']) + else: + actor_0_zero_carry = None + + if self.pop.agent_1.is_recurrent: + rng, subrng = jax.random.split(rng) + actor_1_zero_carry, _ = self.pop.init_carry_agent_1( + subrng, obs['agent_1']) + else: + actor_1_zero_carry = None + + # Reset episodic stats + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.pop.n_agents, self.n_episodes)) + + done = jnp.zeros( + (self.pop.n_agents, self.n_episodes), dtype=jnp.bool_) + + ep_stats = self._rollout_benv( + rollout_rngs[i], + benv, + jax.lax.stop_gradient(params_0), + jax.lax.stop_gradient(params_1), + env_param, + state, + obs, + actor_0_zero_carry, + actor_1_zero_carry, + actor_0_zero_carry, + actor_1_zero_carry, + extra, + done, + ep_stats) + + env_name = self.ext_env_names[i] + mean_return = ep_stats['return'].mean(1) + std_return = ep_stats['return'].std(1) + + if self.env_has_solved_rate[i]: + mean_solved_rate = jax.vmap( + jax.vmap(benv.env.eval_solved_rate))(ep_stats).mean(1) + + for idx in self.agent_idxs: + eval_stats[f'eval/a{idx}:{self.test_return_pre}:{env_name}'] = mean_return[idx].squeeze() + eval_stats[f'eval/a{idx}:{self.test_return_pre}_std:{env_name}'] = std_return[idx].squeeze() + if self.env_has_solved_rate[i]: + eval_stats[f'eval/a{idx}:{self.test_solved_rate_pre}:{env_name}'] = mean_solved_rate[idx].squeeze() + + return eval_stats + + def fake_run(self, rng, params): + eval_stats = {} + for i, env_name in enumerate(self.ext_env_names): + for idx in self.agent_idxs: + eval_stats.update({ + f'eval/a{idx}:{self.test_return_pre}:{env_name}': 0. + }) + eval_stats.update({ + f'eval/a{idx}:{self.test_return_pre}_std:{env_name}': 0. + }) + if self.env_has_solved_rate[i]: + eval_stats.update({ + f'eval/a{idx}:{self.test_solved_rate_pre}:{env_name}': 0., + }) + + return eval_stats diff --git a/src/minimax/runners_ma/paired_runner.py b/src/minimax/runners_ma/paired_runner.py new file mode 100644 index 0000000..79d1ff5 --- /dev/null +++ b/src/minimax/runners_ma/paired_runner.py @@ -0,0 +1,818 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from enum import Enum +from functools import partial +from typing import Dict, Tuple, Optional +import inspect + +import chex +import einops +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec as P +import optax +import flax +import flax.linen as nn +from flax.core.frozen_dict import FrozenDict + +import minimax.envs as envs +from minimax.util import pytree as _tree_util +from minimax.util.rl import ( + AgentPop, + VmapTrainState, + VmapMAPPOTrainState, + RolloutStorage, + RolloutStorageSeperate, + RollingStats, + UEDScore, + compute_ued_scores +) + + +class PAIREDRunner: + """ + Orchestrates rollouts across one or more students and teachers. + The main components at play: + - AgentPop: Manages train state and batched inference logic + for a population of agents. + - BatchUEDEnv: Manages environment step and reset logic for a + population of agents batched over a pair of student and + teacher MDPs. + - RolloutStorage: Manages the storing and sampling of collected txns. + - PPO: Handles PPO updates, which take a train state + batch of txns. + """ + + def __init__( + self, + env_name, + env_kwargs, + ued_env_kwargs, + student_agents, + student_agent_kind, + n_students=2, + n_parallel=1, + n_eval=1, + n_rollout_steps=250, + lr=1e-4, + lr_final=None, + lr_anneal_steps=0, + max_grad_norm=0.5, + discount=0.99, + gae_lambda=0.95, + adam_eps=1e-5, + teacher_lr=None, + teacher_lr_final=None, + teacher_lr_anneal_steps=None, + teacher_discount=0.99, + teacher_gae_lambda=0.95, + teacher_agents=None, + ued_score='relative_regret', + track_env_metrics=False, + n_unroll_rollout=1, + render=False, + n_devices=1, + shaped_reward=False, + ): + assert n_parallel % n_devices == 0, 'Num envs must be divisible by num devices.' + + ued_score = UEDScore[ued_score.upper()] + + assert len(student_agents) == 1, \ + 'Only one type of student supported.' + assert not (n_students > 2 and ued_score in [UEDScore.RELATIVE_REGRET, UEDScore.MEAN_RELATIVE_REGRET]), \ + 'Standard PAIRED uses only 2 students.' + assert teacher_agents is None or len(teacher_agents) == 1, \ + 'Only one type of teacher supported.' + + self.student_agent_kind = student_agent_kind + self.n_students = n_students + self.n_parallel = n_parallel // n_devices + self.n_eval = n_eval + self.n_devices = n_devices + self.step_batch_size = n_students*n_eval*n_parallel + self.n_rollout_steps = n_rollout_steps + self.n_updates = 0 + self.lr = lr + self.lr_final = lr if lr_final is None else lr_final + self.lr_anneal_steps = lr_anneal_steps + self.teacher_lr = \ + lr if teacher_lr is None else lr + self.teacher_lr_final = \ + self.lr_final if teacher_lr_final is None else teacher_lr_final + self.teacher_lr_anneal_steps = \ + lr_anneal_steps if teacher_lr_anneal_steps is None else teacher_lr_anneal_steps + self.max_grad_norm = max_grad_norm + self.adam_eps = adam_eps + self.ued_score = ued_score + self.track_env_metrics = track_env_metrics + + self.shaped_reward = shaped_reward + + self.n_unroll_rollout = n_unroll_rollout + self.render = render + + self.student_pop = AgentPop(student_agents[0], n_agents=n_students) + + if teacher_agents is not None: + self.teacher_pop = AgentPop(teacher_agents[0], n_agents=1) + + # This ensures correct partial-episodic bootstrapping by avoiding + # any termination purely due to timeouts. + # env_kwargs.max_episode_steps = self.n_rollout_steps + 1 + + wrappers_lst = ['monitor_return', 'monitor_ep_metrics'] + if self.student_agent_kind == "mappo": + wrappers_lst = ['world_state_wrapper'] + wrappers_lst + + self.benv = envs.BatchUEDEnv( + env_name=env_name, + n_parallel=self.n_parallel, + n_eval=n_eval, + env_kwargs=env_kwargs, + ued_env_kwargs=ued_env_kwargs, + wrappers=wrappers_lst, + ued_wrappers=[] + ) + self.action_dtype = self.benv.env.action_space().dtype + + self.teacher_n_rollout_steps = \ + self.benv.env.ued_max_episode_steps() + + self.student_rollout = RolloutStorageSeperate( + discount=discount, + gae_lambda=gae_lambda, + n_steps=n_rollout_steps, + n_agents=n_students, + n_envs=self.n_parallel, + n_eval=self.n_eval, + action_space=self.benv.env.action_space(), + obs_space=self.benv.env.observation_space(), + obs_space_shared_shape=self.benv.env.world_state_size(), + agent=self.student_pop.agent + ) + + self.teacher_rollout = RolloutStorage( + discount=teacher_discount, + gae_lambda=teacher_gae_lambda, + n_steps=self.teacher_n_rollout_steps, + n_agents=1, + n_envs=self.n_parallel, + n_eval=1, + action_space=self.benv.env.ued_action_space(), + obs_space=self.benv.env.ued_observation_space(), + agent=self.teacher_pop.agent, + ) + + ued_monitored_metrics = ('return',) + self.ued_rolling_stats = RollingStats( + names=ued_monitored_metrics, + window=10, + ) + + monitored_metrics = self.benv.env.get_monitored_metrics() + self.rolling_stats = RollingStats( + names=monitored_metrics, + window=10, + ) + + self._update_ep_stats = jax.vmap( + jax.vmap(self.rolling_stats.update_stats)) + self._update_ued_ep_stats = jax.vmap( + jax.vmap(self.ued_rolling_stats.update_stats)) + + if self.render: + from envs.viz.grid_viz import GridVisualizer + self.viz = GridVisualizer() + self.viz.show() + + def reset(self, rng): + self.n_updates = 0 + + n_parallel = self.n_parallel*self.n_devices + + rng, student_rng, teacher_rng = jax.random.split(rng, 3) + student_info = self._reset_pop( + student_rng, + self.student_pop, + partial(self.benv.reset, sub_batch_size=n_parallel*self.n_eval), + n_parallel_ep=n_parallel*self.n_eval, + lr_init=self.lr, + lr_final=self.lr_final, + lr_anneal_steps=self.lr_anneal_steps) + + teacher_info = self._reset_teacher_pop( + teacher_rng, + self.teacher_pop, + partial(self.benv.reset_teacher, n_parallel=n_parallel), + n_parallel_ep=n_parallel, + lr_init=self.teacher_lr, + lr_final=self.teacher_lr_final, + lr_anneal_steps=self.teacher_lr_anneal_steps) + + return ( + rng, + *student_info, + *teacher_info + ) + + def _reset_pop( + self, + rng, + pop, + env_reset_fn, + n_parallel_ep=1, + lr_init=3e-4, + lr_final=3e-4, + lr_anneal_steps=0): + rng, *vrngs = jax.random.split(rng, pop.n_agents+1) + reset_out = env_reset_fn(jnp.array(vrngs)) + if len(reset_out) == 2: + obs, state = reset_out + else: + obs, state, extra = reset_out + + n_parallel = self.n_parallel*self.n_devices + + # dummy_obs = jax.tree_util.tree_map(lambda x: x[0], obs) # for one agent only + dummy_obs = self._concat_multi_agent_obs(obs) + dummy_shared_obs = self._concat_multi_agent_obs(obs['world_state']) + + rng, subrng = jax.random.split(rng) + if self.student_pop.agent.is_recurrent: + actor_carry, critic_carry = self.student_pop.init_carry( + subrng, dummy_obs) + # Technically returns actor and critic carry but we only need one + self.zero_carry = jax.tree_map( + lambda x: x.at[:, :self.n_parallel].get(), actor_carry) + else: + actor_carry, critic_carry = None, None + + rng, subrng = jax.random.split(rng) + actor_params, critic_params = self.student_pop.init_params( + subrng, (dummy_obs[0], dummy_shared_obs[0])) + + schedule_fn = optax.linear_schedule( + init_value=-float(lr_init), + end_value=-float(lr_final), + transition_steps=lr_anneal_steps, + ) + + tx_actor = optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + optax.scale_by_adam(eps=self.adam_eps), + optax.scale_by_schedule(schedule_fn), + ) + + tx_critic = optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + optax.scale_by_adam(eps=self.adam_eps), + optax.scale_by_schedule(schedule_fn), + ) + + shaped_reward_coeff_value = 1.0 if self.shaped_reward else 0.0 + shaped_reward_coeff = jnp.full( + (self.n_students, 1), fill_value=shaped_reward_coeff_value) + train_state = VmapMAPPOTrainState.create( + actor_apply_fn=self.student_pop.agent.evaluate_action, + actor_params=actor_params, + actor_tx=tx_actor, + critic_apply_fn=self.student_pop.agent.get_value, + critic_params=critic_params, + critic_tx=tx_critic, + shaped_reward_coeff=shaped_reward_coeff, + ) + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.n_students, n_parallel*self.n_eval)) + + start_state = state + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(pop.n_agents, n_parallel_ep)) + + return train_state, state, obs, actor_carry, critic_carry, ep_stats + + def _reset_teacher_pop( + self, + rng, + pop, + env_reset_fn, + n_parallel_ep=1, + lr_init=3e-4, + lr_final=3e-4, + lr_anneal_steps=0): + rng, *vrngs = jax.random.split(rng, pop.n_agents+1) + reset_out = env_reset_fn(jnp.array(vrngs)) + if len(reset_out) == 2: + obs, state = reset_out + else: + obs, state, extra = reset_out + dummy_obs = jax.tree_util.tree_map( + lambda x: x[0], obs) # for one agent only + + rng, subrng = jax.random.split(rng) + if pop.agent.is_recurrent: + carry = pop.init_carry(subrng, obs) + else: + carry = None + + rng, subrng = jax.random.split(rng) + params = pop.init_params(subrng, dummy_obs) + + schedule_fn = optax.linear_schedule( + init_value=-float(lr_init), + end_value=-float(lr_final), + transition_steps=lr_anneal_steps, + ) + + tx = optax.chain( + optax.clip_by_global_norm(self.max_grad_norm), + optax.scale_by_adam(eps=self.adam_eps), + optax.scale_by_schedule(schedule_fn), + ) + + train_state = VmapTrainState.create( + apply_fn=pop.agent.evaluate, + params=params, + tx=tx + ) + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(pop.n_agents, n_parallel_ep)) + + return train_state, state, obs, carry, ep_stats + + def get_checkpoint_state(self, state): + _state = list(state) + _state[1] = state[1].state_dict + _state[7] = state[7].state_dict + + return _state + + def load_checkpoint_state(self, runner_state, state): + runner_state = list(runner_state) + runner_state[1] = runner_state[1].load_state_dict(state[1]) + runner_state[7] = runner_state[7].load_state_dict(state[7]) + + return tuple(runner_state) + + @partial(jax.jit, static_argnums=(0, 2,)) + def _get_ma_transition( + self, + rng, + pop, + actor_params, + critic_params, + obs, + actor_carry, + critic_carry, + done + ): + ma_obs = self._concat_multi_agent_obs(obs) + _, pi_params, next_actor_carry = jax.vmap(pop.act, in_axes=(None, 2, 2, None))( + actor_params, ma_obs, actor_carry, done) + next_actor_carry = jax.tree_map(lambda x: einops.rearrange( + x, 't n a d -> a t n d'), next_actor_carry) + shared_obs = self._concat_multi_agent_obs(obs['world_state']) + value, next_critic_carry = jax.vmap(pop.get_value, in_axes=(None, 2, 2, None))( + critic_params, shared_obs, critic_carry, done) + next_critic_carry = jax.tree_map(lambda x: einops.rearrange( + x, 't n a d -> a t n d'), next_critic_carry) + + pi = pop.get_action_dist(pi_params, dtype=self.action_dtype) + rng, subrng = jax.random.split(rng) + action = pi.sample(seed=subrng) + log_pi = pi.log_prob(action) + + env_action = { + 'agent_0': action[0], + 'agent_1': action[1] + } + + # # Add transition to storage + log_pi = einops.rearrange(log_pi, 'a s n -> s n a') + value = einops.rearrange(value, 'a s n -> s n a') + action = einops.rearrange(action, 'a s n -> s n a') + + return ( + shared_obs, + value, + log_pi, + env_action, + action, + (jax.tree_map(lambda x: einops.rearrange( + x, 'n a s d -> s n a d'), next_actor_carry), + jax.tree_map(lambda x: einops.rearrange( + x, 'n a s d -> s n a d'), next_critic_carry)), + ) + + @partial(jax.jit, static_argnums=(0, 2, 3)) + def _get_transition( + self, + rng, + pop, + rollout_mgr, + rollout, + params, + state, + obs, + carry, + done, + reset_state=None, + extra=None): + # Sample action + if type(params) == tuple and len(params) == 2: + actor_carry, critic_carry = carry + actor_params, critic_params = params + shared_obs, value, log_pi, env_action, action, next_carry = self._get_ma_transition( + rng, + pop, + actor_params, + critic_params, + obs, + actor_carry, + critic_carry, + done + ) + is_multi_agent = True + else: + value, pi_params, next_carry = pop.act(params, obs, carry, done) + pi = pop.get_action_dist(pi_params, dtype=self.action_dtype) + rng, subrng = jax.random.split(rng) + env_action = pi.sample(seed=subrng) + action = env_action # Is the same in single agent case but a dict in multi_agent + log_pi = pi.log_prob(action) + is_multi_agent = False + shared_obs = None + + rng, *vrngs = jax.random.split(rng, pop.n_agents+1) + + if pop is self.student_pop: + step_fn = self.benv.step_student + else: + step_fn = self.benv.step_teacher + step_args = (jnp.array(vrngs), state, env_action) + + if reset_state is not None: # Needed for student to reset to same instance + step_args += (reset_state,) + + if extra is not None: + step_args += (extra,) + next_obs, next_state, reward, done, info, extra = step_fn( + *step_args) + else: + next_obs, next_state, reward, done, info = step_fn(*step_args) + + if is_multi_agent: + obs = self._concat_multi_agent_obs(obs) + + # Add transition to storage + if shared_obs is not None: + done_ = jnp.concatenate( + [done[:, :, jnp.newaxis], done[:, :, jnp.newaxis]], axis=2) + # jax.debug.print("info r {i}, info sr {s}, r {r}", i=jnp.sum( + # info["sparse_reward"]), s=jnp.sum(info["shaped_reward"]), r=jnp.sum(reward)) + + step = (obs, shared_obs, action, + info["sparse_reward"], info["shaped_reward"], done_, log_pi, value) + else: + step = (obs, action, reward, done, log_pi, value) + + if is_multi_agent: + if carry[0] is not None: + step += (carry[0], carry[1]) # Actor and Critic + else: + if carry is not None: + step += (carry,) + + rollout = rollout_mgr.append(rollout, *step) + + if self.render and pop is self.student_pop: + self.viz.render( + self.benv.env.env.params, + jax.tree_util.tree_map(lambda x: x[0][0], state)) + + return rollout, next_state, next_obs, next_carry, done, info, extra + + @partial(jax.jit, static_argnums=(0, 2, 3, 4)) + def _rollout( + self, + rng, + pop, + rollout_mgr, + n_steps, + params, + state, + obs, + carry, + done, + reset_state=None, + extra=None, + ep_stats=None): + rngs = jax.random.split(rng, n_steps) + + rollout = rollout_mgr.reset() + + def _scan_rollout(scan_carry, rng): + (rollout, + state, + obs, + carry, + done, + extra, + ep_stats) = scan_carry + + next_scan_carry = \ + self._get_transition( + rng, + pop, + rollout_mgr, + rollout, + params, + state, + obs, + carry, + done, + reset_state, + extra) + + (rollout, + next_state, + next_obs, + next_carry, + done, + info, + extra) = next_scan_carry + + if ep_stats is not None: + _ep_stats_update_fn = self._update_ep_stats \ + if pop is self.student_pop else self._update_ued_ep_stats + + ep_stats = _ep_stats_update_fn(ep_stats, done, info) + + return (rollout, next_state, next_obs, next_carry, done, extra, ep_stats), None + + (rollout, state, obs, carry, done, extra, ep_stats), _ = jax.lax.scan( + _scan_rollout, + (rollout, state, obs, carry, done, extra, ep_stats), + rngs, + length=n_steps, + unroll=self.n_unroll_rollout + ) + + return rollout, state, obs, carry, extra, ep_stats + + @partial(jax.jit, static_argnums=(0,)) + def _compile_stats(self, + update_stats, ep_stats, + ued_update_stats, ued_ep_stats, + env_metrics=None, + grad_stats=None, + ued_grad_stats=None, + shaped_reward_coeff=None + ): + mean_returns_by_student = jax.vmap( + lambda x: x.mean())(ep_stats['return']) + mean_returns_by_teacher = jax.vmap( + lambda x: x.mean())(ued_ep_stats['return']) + + mean_ep_stats = jax.vmap(lambda info: jax.tree_map(lambda x: x.mean(), info))( + {k: ep_stats[k] for k in self.rolling_stats.names} + ) + ued_mean_ep_stats = jax.vmap(lambda info: jax.tree_map(lambda x: x.mean(), info))( + {k: ued_ep_stats[k] for k in self.ued_rolling_stats.names} + ) + + student_stats = { + f'mean_{k}': v for k, v in mean_ep_stats.items() + } + student_stats.update(update_stats) + + stats = {} + + if shaped_reward_coeff is not None: + stats.update( + {"shaped_reward_coeff": shaped_reward_coeff}) + + for i in range(self.n_students): + _student_stats = jax.tree_util.tree_map( + lambda x: x[i], student_stats) # for agent0 + stats.update({f'{k}_a{i}': v for k, v in _student_stats.items()}) + + teacher_stats = { + f'mean_{k}_tch': v for k, v in ued_mean_ep_stats.items() + } + teacher_stats.update({ + f'{k}_tch': v[0] for k, v in ued_update_stats.items() + }) + stats.update(teacher_stats) + + if self.n_devices > 1: + stats = jax.tree_map(lambda x: jax.lax.pmean(x, 'device'), stats) + + return stats + + def get_shmap_spec(self): + runner_state_size = len(inspect.signature(self.run).parameters) + in_spec = [P(None, 'device'),]*(runner_state_size) + out_spec = [P(None, 'device'),]*(runner_state_size) + + in_spec[:2] = [P(None),]*2 + in_spec[6] = P(None) + in_spec = tuple(in_spec) + out_spec = (P(None),) + in_spec + + return in_spec, out_spec + + @partial(jax.jit, static_argnums=(0,)) + def run( + self, + rng, + train_state, + state, + obs, + actor_carry, + critic_carry, + ep_stats, + ued_train_state, + ued_state, + ued_obs, + ued_carry, + ued_ep_stats): + """ + Perform one update step: rollout teacher + students + """ + if self.n_devices > 1: + rng = jax.random.fold_in(rng, jax.lax.axis_index('device')) + + # === Reset teacher env + rollout teacher + rng, *vrngs = jax.random.split(rng, self.teacher_pop.n_agents+1) + ued_reset_out = self.benv.reset_teacher(jnp.array(vrngs)) + if len(ued_reset_out) > 2: + ued_obs, ued_state, ued_extra = ued_reset_out + else: + ued_obs, ued_state = ued_reset_out + ued_extra = None + + # Reset UED ep_stats + if self.ued_rolling_stats is not None: + ued_ep_stats = self.ued_rolling_stats.reset_stats( + batch_shape=(1, self.n_parallel)) + else: + ued_ep_stats = None + + tch_rollout_batch_shape = (1, self.n_parallel*self.n_eval) + done = jnp.zeros(tch_rollout_batch_shape, dtype=jnp.bool_) + rng, subrng = jax.random.split(rng) + ued_rollout, ued_state, ued_obs, ued_carry, _, ued_ep_stats = \ + self._rollout( + subrng, + self.teacher_pop, + self.teacher_rollout, + self.teacher_n_rollout_steps, + jax.lax.stop_gradient(ued_train_state.params), + ued_state, + ued_obs, + ued_carry, + done, + extra=ued_extra, + ep_stats=ued_ep_stats + ) + + # === Reset student to new envs + rollout students + rng, *vrngs = jax.random.split(rng, self.teacher_pop.n_agents+1) + obs, state, extra = jax.tree_util.tree_map( + lambda x: x.squeeze(0), self.benv.reset_student( + jnp.array(vrngs), + ued_state, + self.student_pop.n_agents)) + reset_state = state + + # Reset student ep_stats + st_rollout_batch_shape = (self.n_students, self.n_parallel*self.n_eval) + ep_stats = self.rolling_stats.reset_stats( + batch_shape=st_rollout_batch_shape) + + done = jnp.zeros(st_rollout_batch_shape, dtype=jnp.bool_) + rng, subrng = jax.random.split(rng) + rollout, state, obs, carry, extra, ep_stats = \ + self._rollout( + subrng, + self.student_pop, + self.student_rollout, + self.n_rollout_steps, + (jax.lax.stop_gradient(train_state.actor_params), + jax.lax.stop_gradient(train_state.critic_params)), + state, + obs, + (actor_carry, critic_carry), + done, + reset_state=reset_state, + extra=extra, + ep_stats=ep_stats) + + reward = rollout["rewards"].sum(axis=1).mean(-1)[:, :, jnp.newaxis] + shaped_reward = rollout["shaped_rewards"].sum( + axis=1).mean(-1)[:, :, jnp.newaxis] + + ep_stats["reward"] = reward + ep_stats["shaped_reward"] = shaped_reward + ep_stats["shaped_reward_scaled_by_shaped_reward_coeff"] = shaped_reward * \ + train_state.shaped_reward_coeff.mean() + ep_stats["reward_p_shaped_reward_scaled"] = reward + shaped_reward * \ + train_state.shaped_reward_coeff.mean() + + # === Update student with PPO + # PPOAgent vmaps over the train state and batch. Batch must be N x EM + _, critic_carry = carry + shared_obs = self._concat_multi_agent_obs(obs['world_state']) + value, _ = jax.vmap(self.student_pop.get_value, in_axes=(None, 2, 2))( + jax.lax.stop_gradient(train_state.critic_params), + shared_obs, + critic_carry + ) + + jax.debug.print( + "train_state.shaped_reward_coeff {s}", s=train_state.shaped_reward_coeff) + + value = einops.rearrange( + value, "n_env_agents n_students n_parallel -> n_students n_parallel n_env_agents") + train_batch = self.student_rollout.get_batch( + rollout, + value, + train_state.shaped_reward_coeff + ) + + rng, subrng = jax.random.split(rng) + train_state, update_stats = self.student_pop.update( + subrng, train_state, train_batch) + + # === Update teacher with PPO + # - Compute returns per env per agent + # - Compute batched returns based on returns per env per agent + ued_score, _ = compute_ued_scores( + self.ued_score, train_batch, self.n_eval) + ued_rollout = self.teacher_rollout.set_final_reward( + ued_rollout, ued_score) + ued_train_batch = self.teacher_rollout.get_batch( + ued_rollout, + jnp.zeros((1, self.n_parallel)) # Last step terminates episode + ) + + ued_ep_stats = self._update_ued_ep_stats( + ued_ep_stats, + jnp.ones((1, len(ued_score), 1), dtype=jnp.bool_), + {'return': jnp.expand_dims(ued_score, (0, -1))} + ) + + # Update teacher, batch must be 1 x Ex1 + rng, subrng = jax.random.split(rng) + ued_train_state, ued_update_stats = self.teacher_pop.update( + subrng, ued_train_state, ued_train_batch) + + # -------------------------------------------------- + # Collect metrics + if self.track_env_metrics: + env_metrics = self.benv.get_env_metrics(reset_state) + else: + env_metrics = None + + grad_stats, ued_grad_stats = None, None + + stats = self._compile_stats( + update_stats, ep_stats, + ued_update_stats, ued_ep_stats, + env_metrics, + grad_stats, ued_grad_stats, + shaped_reward_coeff=train_state.shaped_reward_coeff[0]) + stats.update(dict(n_updates=train_state.n_updates[0])) + + train_state = train_state.increment() + ued_train_state = ued_train_state.increment() + self.n_updates += 1 + + return ( + stats, + rng, + train_state, state, obs, actor_carry, critic_carry, ep_stats, + ued_train_state, ued_state, ued_obs, ued_carry, ued_ep_stats, reset_state + ) + + def _concat_multi_agent_obs(self, obs: Dict[str, chex.Array]) -> chex.Array: + """Concatenates a obs dictionary that was built for two env agents. + Doubles the number of parallel_envs, i.e. (num_students, n_parallel, ...) -> (num_students, 2*n_parallel, ...) + """ + return jnp.concatenate([obs['agent_0'][:, :, jnp.newaxis, :], obs['agent_1'][:, :, jnp.newaxis, :]], axis=2) + + # def _double_world_state(self, world_state: chex.Array) -> chex.Array: + # """Doubles the world state to simulate two agents. + # Doubles the number of parallel_envs, i.e. (num_students, n_parallel, ...) -> (num_students, 2*n_parallel, ...) + # """ + # return jnp.concatenate([world_state[:, :, jnp.newaxis, :], world_state[:, :, jnp.newaxis, :]], axis=2) diff --git a/src/minimax/runners_ma/plr_runner.py b/src/minimax/runners_ma/plr_runner.py new file mode 100644 index 0000000..e6cf373 --- /dev/null +++ b/src/minimax/runners_ma/plr_runner.py @@ -0,0 +1,578 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from enum import Enum + +import einops +import numpy as np +import jax +import jax.numpy as jnp + +import minimax.envs as envs +from minimax.runners_ma.dr_runner import DRRunner +from minimax.util import pytree as _tree_util +from minimax.util.rl import ( + UEDScore, + compute_ued_scores, + PopPLRManager +) + + +class MutationCriterion(Enum): + BATCH = 0 + EASY = 1 + HARD = 2 + + +class PLRRunner(DRRunner): + def __init__( + self, + *, + replay_prob=0.5, + buffer_size=100, + staleness_coef=0.3, + use_score_ranks=True, + temp=1.0, + min_fill_ratio=0.5, + use_robust_plr=False, + use_parallel_eval=False, + ued_score='l1_value_loss', + force_unique=False, # Slower if True + mutation_fn=None, + n_mutations=0, + mutation_criterion='batch', + mutation_subsample_size=1, + **kwargs): + use_mutations = mutation_fn is not None + if use_parallel_eval: + replay_prob = 1.0 # Replay every rollout cycle + # Force batch mutations (no UED scores) + mutation_criterion = 'batch' + self._n_parallel_batches = 3 if use_mutations else 2 + kwargs['n_parallel'] *= self._n_parallel_batches + + super().__init__(**kwargs) + + self.replay_prob = replay_prob + self.buffer_size = buffer_size + self.staleness_coef = staleness_coef + self.temp = temp + self.use_score_ranks = use_score_ranks + self.min_fill_ratio = min_fill_ratio + self.use_robust_plr = use_robust_plr + self.use_parallel_eval = use_parallel_eval + self.ued_score = UEDScore[ued_score.upper()] + + self.use_mutations = use_mutations + if self.use_mutations: + self.mutation_fn = envs.get_mutator( + self.benv.env_name, mutation_fn) + else: + self.mutation_fn = None + self.n_mutations = n_mutations + self.mutation_criterion = MutationCriterion[mutation_criterion.upper()] + self.mutation_subsample_size = mutation_subsample_size + + self.force_unique = force_unique + if force_unique: + self.comparator_fn = envs.get_comparator(self.benv.env_name) + else: + self.comparator_fn = None + + if mutation_fn is not None and mutation_criterion != 'batch': + assert self.n_parallel % self.mutation_subsample_size == 0, \ + 'Number of parallel envs must be divisible by mutation subsample size.' + + def reset(self, rng): + runner_state = list(super().reset(rng)) + rng = runner_state[0] + runner_state[0], subrng = jax.random.split(rng) + example_state = self.benv.env.reset(rng)[1] + + self.plr_mgr = PopPLRManager( + n_agents=self.n_students, + example_level=example_state, + ued_score=self.ued_score, + replay_prob=self.replay_prob, + buffer_size=self.buffer_size, + staleness_coef=self.staleness_coef, + temp=self.temp, + use_score_ranks=self.use_score_ranks, + min_fill_ratio=self.min_fill_ratio, + use_robust_plr=self.use_robust_plr, + use_parallel_eval=self.use_parallel_eval, + comparator_fn=self.comparator_fn, + n_devices=self.n_devices + ) + plr_buffer = self.plr_mgr.reset(self.n_students) + + train_state = runner_state[1] + train_state = train_state.replace(plr_buffer=plr_buffer) + if self.n_devices == 1: + runner_state[1] = train_state + else: + plr_buffer = jax.tree_map(lambda x: x.repeat( + self.n_devices, 1), plr_buffer) # replicate plr buffer + # Return PLR buffer directly to make shmap easier + runner_state += (plr_buffer,) + + self.dummy_eval_output = self._create_dummy_eval_output(train_state) + + return tuple(runner_state) + + def _create_dummy_eval_output(self, train_state): + rng, * \ + vrngs = jax.random.split(jax.random.PRNGKey(0), self.n_students+1) + obs, state, extra = self.benv.reset(jnp.array(vrngs)) + + ep_stats = self.rolling_stats.reset_stats( + batch_shape=(self.n_students, self.n_parallel*self.n_eval)) + + ued_scores = jnp.zeros((self.n_students, self.n_parallel)) + + if self.student_pop.agent.is_recurrent: + actor_carry, critic_carry = self.zero_carry, self.zero_carry + else: + actor_carry, critic_carry = None, None + rollout = self.student_rollout.reset() + + # Map over the n_env_agents dimension in this multi agent rl setting. + # Dimensions are (n_students, n_parallel, n_env_agents, ...) + value, _ = jax.vmap(self.student_pop.get_value, in_axes=(None, 2, 2))( + jax.lax.stop_gradient(train_state.critic_params), + self._concat_multi_agent_obs(obs["world_state"]), + critic_carry, + ) + + value = einops.rearrange( + value, + "n_env_agents n_students n_parallel -> n_students n_parallel n_env_agents") + + jax.debug.print( + "train_state.shaped_reward_coeff {s}", s=train_state.shaped_reward_coeff) + + train_batch = self.student_rollout.get_batch( + rollout, value, train_state.shaped_reward_coeff + ) + + return ( + rng, + train_state, + state, + state, + obs, + actor_carry, + critic_carry, + extra, + ep_stats, + state, + train_batch, + ued_scores + ) + + @partial(jax.jit, static_argnums=(0, 8)) + def _eval_and_update_plr( + self, + rng, + levels, + level_idxs, + train_state, + update_plr, + parent_idxs=None, + dupe_mask=None, + fake=False): + # Collect rollout and optionally update plr buffer + # Returns train_batch and ued_scores + if fake: + dummy_eval_output = list(self.dummy_eval_output) + dummy_eval_output[1] = train_state + return tuple(dummy_eval_output) + + rollout_batch_shape = (self.n_students, self.n_parallel*self.n_eval) + obs, state, extra = self.benv.set_state(levels) + ep_stats = self.rolling_stats.reset_stats( + batch_shape=rollout_batch_shape) + + rollout_start_state = state + + done = jnp.zeros(rollout_batch_shape, dtype=jnp.bool_) + if self.student_pop.agent.is_recurrent: + actor_carry = self.zero_carry + critic_carry = self.zero_carry + else: + actor_carry, critic_carry = None, None + + rng, subrng = jax.random.split(rng) + start_state = state + rollout, state, start_state, obs, actor_carry, critic_carry, extra, ep_stats, train_state = \ + self._rollout_students( + subrng, + train_state, + state, + start_state, + obs, + actor_carry, + critic_carry, + done, + extra, + ep_stats + ) + + reward = rollout["rewards"].sum(axis=1).mean(-1)[:, :, jnp.newaxis] + shaped_reward = rollout["shaped_rewards"].sum( + axis=1).mean(-1)[:, :, jnp.newaxis] + + ep_stats["reward"] = reward + ep_stats["shaped_reward"] = shaped_reward + ep_stats["shaped_reward_scaled_by_shaped_reward_coeff"] = shaped_reward * \ + train_state.shaped_reward_coeff + ep_stats["reward_p_shaped_reward_scaled"] = reward + shaped_reward * \ + train_state.shaped_reward_coeff + + shared_obs = self._concat_multi_agent_obs(obs['world_state']) + value, _ = jax.vmap(self.student_pop.get_value, in_axes=(None, 2, 2))( + jax.lax.stop_gradient(train_state.critic_params), + shared_obs, + critic_carry + ) + + value = einops.rearrange( + value, "n_env_agents n_students n_parallel -> n_students n_parallel n_env_agents") + train_batch = self.student_rollout.get_batch( + rollout, + value, + train_state.shaped_reward_coeff + ) + + # Update PLR buffer + if self.ued_score == UEDScore.MAX_MC: + max_returns = jax.vmap(lambda x, y: x.at[y].get())( + train_state.plr_buffer.max_returns, level_idxs) + max_returns = jnp.where( + jnp.greater_equal(level_idxs, 0), + max_returns, + jnp.full_like(max_returns, -jnp.inf) + ) + ued_info = {'max_returns': max_returns} + else: + ued_info = None + ued_scores, ued_score_info = compute_ued_scores( + self.ued_score, train_batch, self.n_eval, info=ued_info, ignore_val=-jnp.inf, per_agent=True) + next_plr_buffer = self.plr_mgr.update( + train_state.plr_buffer, + levels=levels, + level_idxs=level_idxs, + ued_scores=ued_scores, + dupe_mask=dupe_mask, + info=ued_score_info, + ignore_val=-jnp.inf, + parent_idxs=parent_idxs) + + next_plr_buffer = jax.vmap( + lambda update, new, prev: jax.tree_map( + lambda x, y: jax.lax.select(update, x, y), new, prev) + )(update_plr, next_plr_buffer, train_state.plr_buffer) + + train_state = train_state.replace(plr_buffer=next_plr_buffer) + + return ( + rng, + train_state, + state, + start_state, + obs, + actor_carry, + critic_carry, + extra, + ep_stats, + rollout_start_state, + train_batch, + ued_scores, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _mutate_levels(self, rng, levels, level_idxs, ued_scores=None): + if not self.use_mutations: + return levels, level_idxs, jnp.full_like(level_idxs, -1) + + def upsample_levels(levels, level_idxs, subsample_idxs): + subsample_idxs = subsample_idxs.repeat( + self.n_parallel//self.mutation_subsample_size, -1) + parent_idxs = level_idxs.take(subsample_idxs) + levels = jax.vmap( + lambda x, y: jax.tree_map( + lambda _x: jnp.array(_x).take(y, 0), x) + )(levels, parent_idxs) + + return levels, parent_idxs + + if self.mutation_criterion == MutationCriterion.BATCH: + parent_idxs = level_idxs + + if self.mutation_criterion == MutationCriterion.EASY: + _, top_level_idxs = jax.lax.approx_min_k( + ued_scores, self.mutation_subsample_size) + levels, parent_idxs = upsample_levels( + levels, level_idxs, top_level_idxs) + + elif self.mutation_criterion == MutationCriterion.HARD: + _, top_level_idxs = jax.lax.approx_max_k( + ued_scores, self.mutation_subsample_size) + levels, parent_idxs = upsample_levels( + levels, level_idxs, top_level_idxs) + + n_parallel = level_idxs.shape[-1] + vrngs = jax.vmap(lambda subrng: jax.random.split(subrng, n_parallel))( + jax.random.split(rng, self.n_students) + ) + + mutated_levels = jax.vmap( + lambda *args: jax.vmap(self.mutation_fn, + in_axes=(0, None, 0, None))(*args), + in_axes=(0, None, 0, None) + )(vrngs, self.benv.env_params, levels, self.n_mutations) + + # Mutated levels do not have existing idxs in the PLR buffer. + mutated_level_idxs = jnp.full((self.n_students, n_parallel), -1) + + return mutated_levels, mutated_level_idxs, parent_idxs + + def _efficient_grad_update(self, rng, train_state, train_batch, is_replay): + # PPOAgent vmaps over the train state and batch. Batch must be N x EM + skip_grad_update = jnp.logical_and(self.use_robust_plr, ~is_replay) + + if self.n_students == 1: + train_state, stats = jax.lax.cond( + skip_grad_update[0], + partial(self.student_pop.update, fake=True), + self.student_pop.update, + *(rng, train_state, train_batch) + ) + elif self.n_students > 1: # Have to vmap all students + take only students that need updates + _, dummy_stats = jax.vmap( + lambda *_: self.student_pop.agent.get_empty_update_stats())(np.arange(self.n_students)) + _train_state, stats = self.student.update( + rng, train_state, train_batch) + train_state, stats = jax.vmap(lambda cond, x, y: + jax.tree_map(lambda _cond, _x, _y: jax.lax.select(_cond, _x, _y), cond, x, y))( + is_replay, (train_state, + stats), (_train_state, dummy_stats) + ) + + return train_state, stats + + @partial(jax.jit, static_argnums=(0,)) + def _compile_stats(self, update_stats, ep_stats, env_metrics=None, plr_stats=None, shaped_reward_coeff=None): + stats = super()._compile_stats(update_stats, ep_stats, env_metrics) + + if plr_stats is not None: + plr_stats = jax.vmap(lambda info: jax.tree_map( + lambda x: x.mean(), info))(plr_stats) + + if shaped_reward_coeff is not None: + stats['shaped_reward_coeff'] = shaped_reward_coeff + + if self.n_students > 1: + _plr_stats = {} + for i in range(self.n_students): + _student_plr_stats = jax.tree_util.tree_map( + lambda x: x[i], plr_stats) # for agent0 + _plr_stats.update( + {f'{k}_a{i}': v for k, v in _student_plr_stats.items()}) + plr_stats = _plr_stats + else: + plr_stats = jax.tree_map(lambda x: x[0], plr_stats) + + stats.update({f'plr_{k}': v for k, v in plr_stats.items()}) + + if self.n_devices > 1: + stats = jax.tree_map(lambda x: jax.lax.pmean(x, 'device'), stats) + + return stats + + @partial(jax.jit, static_argnums=(0,)) + def run( + self, + rng, + train_state, + state, + start_state, + obs, + carry=None, + extra=None, + ep_stats=None, + plr_buffer=None): + # If device sharded, load sharded PLR buffer into train state + if self.n_devices > 1: + rng = jax.random.fold_in(rng, jax.lax.axis_index('device')) + train_state = train_state.replace(plr_buffer=plr_buffer) + + # Sample next training levels via PLR + rng, *vrngs = jax.random.split(rng, self.n_students+1) + obs, state, extra = self.benv.reset( + jnp.array(vrngs), self.n_parallel, 1) + + if self.use_parallel_eval: + n_level_samples = self.n_parallel//self._n_parallel_batches + new_levels = jax.tree_map( + lambda x: x.at[:, n_level_samples:2*n_level_samples].get(), state) + else: + n_level_samples = self.n_parallel + new_levels = state + + rng, subrng = jax.random.split(rng) + levels, level_idxs, is_replay, next_plr_buffer = \ + self.plr_mgr.sample(subrng, train_state.plr_buffer, + new_levels, n_level_samples) + train_state = train_state.replace(plr_buffer=next_plr_buffer) + + # If use_parallel_eval=True, need to combine replay and non-replay levels together + # Need to mutate levels as well + parent_idxs = jnp.full((self.n_students, self.n_parallel), -1) + if self.use_parallel_eval: # Parallel ACCEL + new_level_idxs = jnp.full_like(parent_idxs, -1) + + _all_levels = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=n_level_samples, src_len=n_level_samples), + )(state, levels) + _all_level_idxs = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=n_level_samples, src_len=n_level_samples), + )(new_level_idxs, level_idxs) + + if self.use_mutations: + rng, subrng = jax.random.split(rng) + mutated_levels, mutated_level_idxs, _parent_idxs = self._mutate_levels( + subrng, levels, level_idxs) + + fallback_levels = jax.tree_map( + lambda x: x.at[:, -n_level_samples:].get(), state) + fallback_level_idxs = jnp.full_like(mutated_level_idxs, -1) + + mutated_levels = jax.vmap( + lambda cond, x, y: jax.tree_map( + lambda _x, _y: jax.lax.select(cond, _x, _y), x, y + ))(is_replay, mutated_levels, fallback_levels) + + mutated_level_idxs = jax.vmap( + lambda cond, x, y: jax.tree_map( + lambda _x, _y: jax.lax.select(cond, _x, _y), x, y + ))(is_replay, mutated_level_idxs, fallback_level_idxs) + + _parent_idxs = jax.vmap( + lambda cond, x, y: jax.tree_map( + lambda _x, _y: jax.lax.select(cond, _x, _y), x, y + ))(is_replay, _parent_idxs, fallback_level_idxs) + + mutated_levels_start_idx = 2*n_level_samples + _all_levels = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=mutated_levels_start_idx, src_len=n_level_samples), + )(_all_levels, mutated_levels) + _all_level_idxs = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=mutated_levels_start_idx, src_len=n_level_samples), + )(_all_level_idxs, mutated_level_idxs) + parent_idxs = jax.vmap( + lambda x, y: _tree_util.pytree_merge( + x, y, start_idx=mutated_levels_start_idx, src_len=n_level_samples), + )(parent_idxs, _parent_idxs) + + levels = _all_levels + level_idxs = _all_level_idxs + + # dedupe levels, move into PLR buffer logic + if self.force_unique: + level_idxs, dupe_mask = self.plr_mgr.dedupe_levels( + next_plr_buffer, levels, level_idxs) + else: + dupe_mask = None + + # Evaluate levels + update PLR + result = self._eval_and_update_plr( + rng, levels, level_idxs, train_state, update_plr=jnp.array([True]*self.n_students), parent_idxs=parent_idxs, dupe_mask=dupe_mask) + rng, train_state, state, start_state, obs, actor_carry, critic_carry, extra, ep_stats, \ + rollout_start_state, train_batch, ued_scores = result + + if self.use_parallel_eval: + replay_start_idx = self.n_eval*n_level_samples + replay_end_idx = 2*replay_start_idx + train_batch = jax.vmap( + lambda x: jax.tree_map( + lambda _x: _x.at[:, replay_start_idx:replay_end_idx].get(), x) + )(train_batch) + + # Gradient update + rng, subrng = jax.random.split(rng) + train_state, update_stats = self._efficient_grad_update( + subrng, train_state, train_batch, is_replay) + + # Mutation step + use_mutations = jnp.logical_and(self.use_mutations, is_replay) + # Already mutated above in parallel + use_mutations = jnp.logical_and( + use_mutations, not self.use_parallel_eval) + rng, arng, brng = jax.random.split(rng, 3) + + mutated_levels, mutated_level_idxs, parent_idxs = jax.lax.cond( + use_mutations.any(), + self._mutate_levels, + lambda *_: (levels, level_idxs, jnp.full_like(level_idxs, -1)), + *(arng, levels, level_idxs, ued_scores) + ) + + mutated_dupe_mask = jnp.zeros_like(mutated_level_idxs, dtype=jnp.bool_) + if self.force_unique: # Should move into update plr logic + mutated_level_idxs, mutated_dupe_mask = jax.lax.cond( + use_mutations.any(), + self.plr_mgr.dedupe_levels, + lambda *_: (mutated_level_idxs, mutated_dupe_mask), + *(next_plr_buffer, mutated_levels, mutated_level_idxs) + ) + + mutation_eval_result = jax.lax.cond( + use_mutations.any(), + self._eval_and_update_plr, + partial(self._eval_and_update_plr, fake=True), + *(brng, mutated_levels, mutated_level_idxs, train_state, use_mutations, parent_idxs, mutated_dupe_mask) + ) + train_state = mutation_eval_result[1] + + # Collect training env metrics + if self.track_env_metrics: + env_metrics = self.benv.get_env_metrics(levels) + else: + env_metrics = None + + plr_stats = self.plr_mgr.get_metrics(train_state.plr_buffer) + + stats = self._compile_stats( + update_stats, ep_stats, env_metrics, plr_stats, shaped_reward_coeff=train_state.shaped_reward_coeff) + + if self.n_devices > 1: + plr_buffer = train_state.plr_buffer + train_state = train_state.replace(plr_buffer=None) + + train_state = train_state.increment() + stats.update(dict(n_updates=train_state.n_updates[0])) + + return ( + stats, + rng, + train_state, + state, + start_state, + obs, + carry, + extra, + ep_stats, + plr_buffer, + rollout_start_state, + ) diff --git a/src/minimax/runners_ma/xp_runner.py b/src/minimax/runners_ma/xp_runner.py new file mode 100644 index 0000000..9600280 --- /dev/null +++ b/src/minimax/runners_ma/xp_runner.py @@ -0,0 +1,377 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +import os +import time + +import imageio +import numpy as np +import jax +import jax.numpy as jnp +from jax.experimental import mesh_utils +from jax.experimental.shard_map import shard_map + + +from .eval_runner import EvalRunner +from .dr_runner import DRRunner +from .paired_runner import PAIREDRunner +from .plr_runner import PLRRunner +import minimax.envs as envs +import minimax.models as models +import minimax.agents as agents +from minimax.envs.viz.overcooked_visualizer import OvercookedVisualizer + + +class RunnerInfo: + def __init__( + self, + runner_cls, + is_ued=False): + self.runner_cls = runner_cls + self.is_ued = is_ued + + +RUNNER_INFO = { + 'dr': RunnerInfo( + runner_cls=DRRunner, + ), + 'plr': RunnerInfo( + runner_cls=PLRRunner, + ), + 'paired': RunnerInfo( + runner_cls=PAIREDRunner, + is_ued=True + ) +} + + +class ExperimentRunner: + def __init__( + self, + train_runner, + env_name, + agent_rl_algo, + student_model_name, + student_critic_model_name, + student_agent_kind="mappo", + teacher_model_name=None, + train_runner_kwargs={}, + env_kwargs={}, + ued_env_kwargs={}, + student_rl_kwargs={}, + teacher_rl_kwargs={}, + student_model_kwargs={}, + teacher_model_kwargs={}, + eval_kwargs={}, + eval_env_kwargs={}, + n_devices=1, + shaped_reward_steps=0, + shaped_reward_n_updates=0, + xpid=None + ): + self.env_name = env_name + self.agent_rl_algo = agent_rl_algo + self.is_ued = RUNNER_INFO[train_runner].is_ued + self.xpid = xpid + self.shaped_reward_steps = shaped_reward_steps + self.shaped_reward_n_updates = shaped_reward_n_updates + + dummy_env = envs.make( + env_name, + env_kwargs, + ued_env_kwargs)[0] + + # ---- Make agent ---- + if student_agent_kind == 'mappo': + student_model_kwargs['output_dim'] = dummy_env.action_space().n + student_actor = models.make( + env_name=env_name, + model_name=student_model_name, + **student_model_kwargs + ) + + student_model_kwargs['output_dim'] = 1 + student_critic = models.make( + env_name=env_name, + model_name=student_critic_model_name, + **student_model_kwargs + ) + + student_agent = agents.MAPPOAgent( + actor=student_actor, + critic=student_critic, + n_devices=n_devices, + **student_rl_kwargs + ) + else: + raise ValueError( + f"Unknown student agent kind: {student_agent_kind}") + + # ---- Handle UED-related settings ---- + if self.is_ued: + max_teacher_steps = dummy_env.ued_max_episode_steps() + teacher_model_kwargs['n_scalar_embeddings'] = max_teacher_steps + teacher_model_kwargs['max_scalar'] = max_teacher_steps + teacher_model_kwargs['output_dim'] = dummy_env.ued_action_space().n + + teacher_model = models.make( + env_name=env_name, + model_name=teacher_model_name, + **teacher_model_kwargs + ) + + teacher_agent = agents.PPOAgent( + model=teacher_model, + n_devices=n_devices, + **teacher_rl_kwargs + ) + + train_runner_kwargs.update(dict( + teacher_agents=[teacher_agent] + )) + train_runner_kwargs.update(dict( + ued_env_kwargs=ued_env_kwargs + )) + + # Debug, tabulate student and teacher model + # import jax.numpy as jnp + # dummy_rng = jax.random.PRNGKey(0) + # obs, _ = dummy_env.reset(dummy_rng) + # # hx = student_actor.initialize_carry(dummy_rng, (1,)) + # ued_obs, _ = dummy_env.reset_teacher(dummy_rng) + # # ued_hx = teacher_model.initialize_carry(dummy_rng, (1,)) + + # obs['image'] = jnp.expand_dims(obs['image'], 0) + # ued_obs['image'] = jnp.expand_dims(ued_obs['image'], 0) + + # print(student_actor.tabulate(dummy_rng, obs, None)) + # print(teacher_model.tabulate(dummy_rng, ued_obs, None)) + + # import pdb + # pdb.set_trace() + + # ---- Set up train runner ---- + runner_cls = RUNNER_INFO[train_runner].runner_cls + + # Set up learning rate annealing parameters + lr_init = train_runner_kwargs.lr + lr_final = train_runner_kwargs.lr_final + lr_anneal_steps = train_runner_kwargs.lr_anneal_steps + + if lr_final is None: + train_runner_kwargs.lr_final = lr_init + if train_runner_kwargs.lr_final == train_runner_kwargs.lr: + train_runner_kwargs.lr_anneal_steps = 0 + + use_shaped_reward = (shaped_reward_steps is not None and shaped_reward_steps > 0) or ( + shaped_reward_n_updates is not None and shaped_reward_n_updates > 0) + + self.runner = runner_cls( + env_name=env_name, + env_kwargs=env_kwargs, + student_agents=[student_agent], + student_agent_kind=student_agent_kind, + n_devices=n_devices, + shaped_reward=use_shaped_reward, + **train_runner_kwargs) + + # ---- Make eval runner ---- + if eval_kwargs.get('env_names') is None: + self.eval_runner = None + else: + self.eval_runner = EvalRunner( + pop=self.runner.student_pop, + env_kwargs=eval_env_kwargs, + **eval_kwargs) + + self._start_tick = 0 + + # ---- Set up device parallelism ---- + self.n_devices = n_devices + if n_devices > 1: + dummy_runner_state = self.reset_train_runner(jax.random.PRNGKey(0)) + self._shmap_run = self._make_shmap_run(dummy_runner_state) + else: + self._shmap_run = None + + @partial(jax.jit, static_argnums=(0,)) + def step(self, runner_state, evaluate=False): + if self.n_devices > 1: + run_fn = self._shmap_run + else: + run_fn = self.runner.run + + stats, *runner_state = run_fn(*runner_state) + + rng = runner_state[0] + rng, subrng = jax.random.split(rng) + + if self.eval_runner is not None: + params = runner_state[1].actor_params + eval_stats = jax.lax.cond( + evaluate, + self.eval_runner.run, + self.eval_runner.fake_run, + *(subrng, params) + ) + else: + eval_stats = {} + + return stats, eval_stats, rng, *runner_state[1:] + + def _make_shmap_run(self, runner_state): + devices = mesh_utils.create_device_mesh((self.n_devices,)) + mesh = Mesh(devices, axis_names=('device')) + + in_specs, out_specs = self.runner.get_shmap_spec() + + return partial(shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False + )(self.runner.run) + + def train( + self, + rng, + agent_algo='ppo', + algo_runner='dr', + n_total_updates=1000, + logger=None, + log_interval=1, + test_interval=1, + checkpoint_interval=0, + archive_interval=0, + archive_init_checkpoint=False, + from_last_checkpoint=False + ): + """ + Entry-point for training + """ + # Load checkpoint if any + runner_state = self.runner.reset(rng) + + if from_last_checkpoint: + last_checkpoint_state = logger.load_last_checkpoint_state() + if last_checkpoint_state is not None: + runner_state = self.runner.load_checkpoint_state( + runner_state, + last_checkpoint_state + ) + self._start_tick = runner_state[1].n_iters[0] + + # Archive initialization weights if necessary + if archive_init_checkpoint: + logger.checkpoint( + self.runner.get_checkpoint_state(runner_state), + index=0, + archive_interval=1) + + # Train loop + log_on = logger is not None and log_interval > 0 + checkpoint_on = checkpoint_interval > 0 or archive_interval > 0 + train_state = runner_state[1] + + tick = self._start_tick + train_steps = tick*self.runner.step_batch_size * \ + self.runner.n_rollout_steps*self.n_devices + real_train_steps = train_steps//self.runner.n_students + + while (train_state.n_updates < n_total_updates).any(): + evaluate = test_interval > 0 and (tick+1) % test_interval == 0 + + start = time.time() + stats, eval_stats, *runner_state = \ + self.step(runner_state, evaluate) + end = time.time() + + start_state = runner_state[-1] + runner_state = runner_state[:-1] + + if evaluate: + stats.update(eval_stats) + else: + stats.update({k: None for k in eval_stats.keys()}) + + train_state = runner_state[1] + + dsteps = self.runner.step_batch_size*self.runner.n_rollout_steps*self.n_devices + real_dsteps = dsteps//self.runner.n_students + train_steps += dsteps + real_train_steps += real_dsteps + + if (self.shaped_reward_steps is not None and self.shaped_reward_steps > 0) or (self.shaped_reward_n_updates is not None and self.shaped_reward_n_updates > 0): + if self.shaped_reward_n_updates: # Meassure based on n_updates + new_shaped_reward_coeff_value = max( + 0.0, 1.0 - (train_state.n_updates[0]/self.shaped_reward_n_updates)) + else: # Meassure based on steps in the env + new_shaped_reward_coeff_value = max( + 0.0, 1.0 - (real_train_steps/self.shaped_reward_steps)) + + new_shaped_reward_coeff = jnp.full( + runner_state[1].shaped_reward_coeff.shape, fill_value=new_shaped_reward_coeff_value) + jax.debug.print("Shaped reward coeff: {a}, real_dsteps: {b}, shaped_reward_steps: {c}", + a=new_shaped_reward_coeff, b=real_dsteps, c=self.shaped_reward_steps) + # runner_state[1] is the training state object where the shaped reward coefficient is stored + runner_state[1] = runner_state[1].set_new_shaped_reward_coeff( + new_shaped_reward_coeff) + + sps = int(dsteps/(end-start)) + real_sps = int(real_dsteps/(end-start)) + time_per_iter = float(end-start) + stats.update(dict( + steps=train_steps, + sps=sps, + real_steps=real_train_steps, + real_sps=real_sps, + time_per_iter=time_per_iter, + )) + + tick += 1 + + jax.debug.print("-----\n{stats}", stats=stats) + if log_on and tick % log_interval == 0: + logger.log(stats, tick, ignore_val=-np.inf) + + if checkpoint_on and tick > 0: + if tick % checkpoint_interval == 0 \ + or (archive_interval > 0 and tick % archive_interval == 0): + + maze_map = start_state.maze_map + agent_dir_idx = start_state.agent_dir_idx + agent_inv = start_state.agent_inv + for i in range(1): # self.runner.n_parallel + + padding = 4 # Fixed + grid = np.asarray( + maze_map[0, i, padding:-padding, padding:-padding, :]) + # Render the state + frame = OvercookedVisualizer._render_grid( + grid, + tile_size=32, + highlight_mask=None, + agent_dir_idx=agent_dir_idx[0][i], + agent_inv=agent_inv[0][i] + ) + + # Save the numpy frame as image + dir = f"{os.getcwd()}/overcooked_teacher_layout_imgs/{self.xpid}/" + + os.makedirs(os.path.dirname(dir), exist_ok=True) + imageio.imwrite( + dir + f"{tick}_{i}.png", frame) + + # Also produce an image of the teachers env output currently + checkpoint_state = \ + self.runner.get_checkpoint_state(runner_state) + logger.checkpoint( + checkpoint_state, + index=tick, + archive_interval=archive_interval) diff --git a/src/minimax/tests/__init__.py b/src/minimax/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/minimax/tests/base_req_rollout_storage.py b/src/minimax/tests/base_req_rollout_storage.py new file mode 100644 index 0000000..f4f1f31 --- /dev/null +++ b/src/minimax/tests/base_req_rollout_storage.py @@ -0,0 +1,116 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import pytest +import numpy as np +import jax +import jax.numpy as jnp + +from util.rl import RolloutStorage +import envs +import models +import agents + + +class RequiresRolloutStorageTestClass: + def setup_class(self): + # Create maze object + env_name = 'Maze' + self.env, self.env_params = envs.make(env_name) + + self.n_steps = 32 + self.n_agents = 2 + self.n_envs = 3 + self.n_eval = 5 + self.rnn_dim = 4 + self.discount = 0.995 + self.gae_lambda = 0.95 + + self.batch_shape = ( + self.n_agents, + self.n_steps, + self.n_envs*self.n_eval, + ) + + self.t_batch_shape = ( + self.n_agents, + self.n_envs*self.n_eval, + ) + + # Create dummy agent + self.agent_model = models.make( + env_name=env_name, + model_name='default_student_cnn', + recurrent_arch='lstm', + recurrent_hidden_dim=self.rnn_dim + ) + + self.agent = agents.PPOAgent( + model=self.agent_model, + ) + + dummy_rng = jax.random.PRNGKey(0) + self.zero_carry_t = \ + self.agent.init_carry( + dummy_rng, + batch_dims=( + self.n_agents, + self.n_envs*self.n_eval) + ) + + # Initialize RolloutStorage obj + self.rollout_mgr = RolloutStorage( + discount=self.discount, + gae_lambda=self.gae_lambda, + n_envs=self.n_envs, + n_eval=self.n_eval, + n_steps=self.n_steps, + action_space=self.env.action_space(), + obs_space=self.env.observation_space(), + agent=self.agent, + n_agents=self.n_agents + ) + + def setup_method(self): + pass + + def _get_dummy_step_components(self): + t_batch_shape = self.t_batch_shape + + dummy_rng = jax.random.PRNGKey(0) + obs, state = self.env.reset(dummy_rng) + obs = jax.tree_util.tree_map(lambda x: x.repeat( + np.prod(t_batch_shape) + ).reshape( + *t_batch_shape, *x.shape + ), obs + ) + action = jnp.ones( + (*t_batch_shape, *self.env.action_space().shape), + dtype=self.env.action_space().dtype) + + done = jnp.ones( + (*t_batch_shape, *self.env.action_space().shape), dtype=jnp.uint8) + + reward = jnp.ones(t_batch_shape, dtype=jnp.float32) + + log_pis_old = jnp.ones(t_batch_shape, dtype=jnp.float32) + + values_old = jnp.ones(t_batch_shape, dtype=jnp.float32) + + carry = self.zero_carry_t + + return ( + obs, + action, + reward, + done, + log_pis_old, + values_old, + carry + ) \ No newline at end of file diff --git a/src/minimax/tests/dummy_test_envs.py b/src/minimax/tests/dummy_test_envs.py new file mode 100644 index 0000000..5c0f32d --- /dev/null +++ b/src/minimax/tests/dummy_test_envs.py @@ -0,0 +1,130 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections import OrderedDict +from typing import Tuple, Optional + +import jax +from flax import struct +import chex + +from envs import environment +from envs import spaces +from envs.registration import register, register_ued + + +@struct.dataclass +class EnvState: + time: int = 0 + terminal: bool = False + +@struct.dataclass +class EnvParams: + reward_per_step: int = 1.0 + max_episode_steps: int = 250 + + +class DummyRewardEnv(environment.Environment): + def __init__(self, reward_per_step=1.0, max_episode_steps=250): + self.reward_per_step = 1.0 + + self.params = EnvParams( + reward_per_step=reward_per_step, + max_episode_steps=max_episode_steps + ) + + @staticmethod + def align_kwargs(kwargs, other_kwargs): + return kwargs + + def reset_env( + self, + key: chex.PRNGKey, + ) -> Tuple[chex.ArrayTree, EnvState]: + state = EnvState( + time=0, + terminal=False + ) + + return self.get_obs(state), state + + def step_env( + self, + key: chex.PRNGKey, + state: EnvState, + action: int, + ) -> Tuple[chex.Array, EnvState, float, bool, dict]: + next_time = state.time + 1 + done = next_time >= self.params.max_episode_steps + + state = state.replace( + time=next_time, + terminal=done + ) + + return ( + jax.lax.stop_gradient(self.get_obs(state)), + jax.lax.stop_gradient(state), + self.params.reward_per_step, + done, + {}, + ) + + def get_obs(self, state: EnvState) -> chex.ArrayTree: + return OrderedDict(dict(time=state.time)) + + @property + def default_params(self) -> EnvParams: + return EnvParams() + + @property + def name(self) -> str: + return "DummyRewardEnv" + + @property + def num_actions(self) -> int: + return len(self.action_set) + + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(1) + + def observation_space(self) -> spaces.Dict: + return spaces.Dict({ + "time": spaces.Discrete(self.params.max_episode_steps), + }) + + def state_space(self) -> spaces.Dict: + return spaces.Dict({ + "time": spaces.Discrete(self.params.max_episode_steps), + "terminal": spaces.Discrete(2), + }) + + def max_episode_steps(self) -> int: + return self.params.max_episode_steps + + # UED-specific + def get_env_instance( + self, + key: chex.PRNGKey, + state: EnvState + ) -> chex.ArrayTree: + return state + + def set_env_instance(self, encoding: chex.ArrayTree): + state = encoding + return self.get_obs(state), state + + +# Register the env as its own UED env +if hasattr(__loader__, 'name'): + module_path = __loader__.name +elif hasattr(__loader__, 'fullname'): + module_path = __loader__.fullname + +register(env_id='DummyRewardEnv', entry_point=module_path + ':DummyRewardEnv') +register_ued(env_id='DummyRewardEnv', entry_point=module_path + ':DummyRewardEnv') diff --git a/src/minimax/tests/test_rollout_storage.py b/src/minimax/tests/test_rollout_storage.py new file mode 100644 index 0000000..3fdd125 --- /dev/null +++ b/src/minimax/tests/test_rollout_storage.py @@ -0,0 +1,163 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import pytest +import numpy as np +import jax +import jax.numpy as jnp + +from tests.base_req_rollout_storage import RequiresRolloutStorageTestClass + + +class TestRolloutStorage(RequiresRolloutStorageTestClass): + def test_reset(self): + rollout = self.rollout_mgr.reset() + + batch_shape = self.batch_shape + obs_space = self.env.observation_space() + for k,v in rollout['obs'].items(): + assert rollout['obs'][k].shape == ( + *batch_shape, + *obs_space.spaces[k].shape + ) + assert (rollout['obs'][k] > 0).sum() == 0 + + assert rollout['actions'].shape == ( + *batch_shape, + *self.env.action_space().shape + ) + + assert rollout['rewards'].shape == batch_shape + assert (rollout['rewards'] > 0).sum() == 0 + + assert rollout['dones'].shape == batch_shape + assert (rollout['dones'] > 0).sum() == 0 + + assert rollout['log_pis_old'].shape == batch_shape + assert (rollout['log_pis_old'] > 0).sum() == 0 + + assert rollout['values_old'].shape == batch_shape + assert (rollout['values_old'] > 0).sum() == 0 + + assert rollout['_t'].shape == (self.n_agents,) + assert (rollout['_t'] > 0).sum() == 0 + + assert rollout['carry'][0].shape == (*batch_shape, self.rnn_dim) + assert rollout['carry'][1].shape == (*batch_shape, self.rnn_dim) + assert jax.tree_util.tree_structure(rollout['carry'][0].at[0,0,0].get()) \ + == jax.tree_util.tree_structure(self.zero_carry_t[0]) + assert jax.tree_util.tree_structure(rollout['carry'][1].at[0,0,0].get()) \ + == jax.tree_util.tree_structure(self.zero_carry_t[1]) + assert (rollout['carry'][0] > 0).sum() == 0 + assert (rollout['carry'][1] > 0).sum() == 0 + + def test_append(self): + # Make sure appending the full rollout length looks right + n_appends = 10 + + t_batch_shape = self.t_batch_shape + + rollout = self.rollout_mgr.reset() + + step = self._get_dummy_step_components() + + for i in range(n_appends): + rollout = self.rollout_mgr.append(rollout, *step) + + t_batch_size = np.prod(t_batch_shape) + + assert rollout['actions'].sum() == t_batch_size*n_appends + assert rollout['rewards'].sum() == t_batch_size*n_appends + assert rollout['dones'].sum() == t_batch_size*n_appends + assert rollout['log_pis_old'].sum() == t_batch_size*n_appends + assert rollout['values_old'].sum() == t_batch_size*n_appends + assert rollout['_t'].mean() == n_appends + + t_overshoot = 2 + for t in range(self.n_steps - n_appends + t_overshoot): + rollout = self.rollout_mgr.append(rollout, *step) + + assert rollout['_t'].mean() == t_overshoot + + def test_compute_gae(self): + # Set up placeholder values + (obs, + action, + _, + _, + log_pi, + _, + carry) = self._get_dummy_step_components() + + # Mark episode done every 8 steps + batch_shape = self.batch_shape + done = jnp.zeros(batch_shape, dtype=jnp.uint8) + done = done.at[:,jnp.arange(4,self.n_steps,4),:].set(1) + + # Reward of 10 at the end of every episode + reward = done*10 + + # Predict 0.1 at every time step + value = jnp.ones(batch_shape)*0.1 + + # Last value is 1 + last_value = jnp.ones(self.t_batch_shape) + + advantages, targets = jax.vmap(self.rollout_mgr.compute_gae)( + value, reward, done, last_value + ) + + _advantages = jnp.zeros(batch_shape) + + next_value = last_value + next_advantage = np.zeros_like(advantages.at[:,0,:].get()) + for t in np.arange(self.n_steps)[::-1]: + _done = done.at[:,t,:].get() + cur_value = value.at[:,t,:].get() + td = reward.at[:,t,:].get() + self.discount*next_value*(1-_done) - cur_value + _advantages = \ + _advantages.at[:,t,:].set(td + self.discount*self.gae_lambda*(1-_done)*next_advantage) + next_advantage = _advantages.at[:,t,:].get() + next_value = cur_value + + _targets = _advantages + value + + assert (_advantages != advantages).sum() == 0 + assert (_targets != targets).sum() == 0 + + def test_get_batch(self): + rollout = self.rollout_mgr.reset() + step = self._get_dummy_step_components() + for i in range(self.n_steps): + rollout = self.rollout_mgr.append(rollout, *step) + + last_value = jnp.ones(self.t_batch_shape) + + batch = self.rollout_mgr.get_batch(rollout, last_value) + + for k,v in batch.obs.items(): + assert (batch.obs[k] != rollout['obs'][k]).sum() == 0 + + assert (batch.actions != rollout['actions']).sum() == 0 + assert (batch.dones != rollout['dones']).sum() == 0 + assert (batch.rewards != rollout['rewards']).sum() == 0 + assert (batch.log_pis != rollout['log_pis_old']).sum() == 0 + assert (batch.values != rollout['values_old']).sum() == 0 + + assert ((batch.advantages + batch.values) != batch.targets).sum() == 0 + + assert (batch.carry[0] != rollout['carry'][0]).sum() == 0 + assert (batch.carry[1] != rollout['carry'][1]).sum() == 0 + + def test_set_final_reward(self): + rollout = self.rollout_mgr.reset() + + final_reward = jnp.ones(self.t_batch_shape)*3 + rollout = self.rollout_mgr.set_final_reward(rollout, final_reward) + + assert (rollout['rewards'].at[:,-1,:].get() != final_reward).sum() == 0 diff --git a/src/minimax/tests/test_ued_scores.py b/src/minimax/tests/test_ued_scores.py new file mode 100644 index 0000000..1323a31 --- /dev/null +++ b/src/minimax/tests/test_ued_scores.py @@ -0,0 +1,74 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import pytest +import numpy as np +import jax +import jax.numpy as jnp + +import util.rl.ued_scores as _ued_scores + +from tests.base_req_rollout_storage import RequiresRolloutStorageTestClass + + +class TestUEDScores(RequiresRolloutStorageTestClass): + def test_compute_ued_scores_returns(self): + (obs, + action, + done, + reward, + log_pi, + value, + carry) = self._get_dummy_step_components() + + # Mark episode done every 8 steps + batch_shape = self.batch_shape + dones = jnp.zeros(batch_shape, dtype=jnp.uint8) + dones = dones.at[:,jnp.arange(4,self.n_steps,4),:].set(1) + + # Reward of 10 at the end of every episode + rewards = jnp.zeros_like(dones, dtype=jnp.float32) + rewards = rewards.at[:,jnp.arange(4,self.n_steps,4),:].set( + jnp.arange(1,8,dtype=jnp.float32).reshape(1,7,1) + ) + + rollout = self.rollout_mgr.reset() + for t in range(self.n_steps): + rollout = self.rollout_mgr.append( + rollout, + obs, + action, + rewards.at[:,t,:].get(), + dones.at[:,t,:].get(), + log_pi, + value, + carry + ) + + score_name = _ued_scores.UEDScore.RETURN + last_value = jnp.zeros(self.t_batch_shape) + batch = self.rollout_mgr.get_batch(rollout, last_value) + ued_score, _ = _ued_scores.compute_ued_scores( + score_name, batch, n_eval=self.n_eval) + + n_agents, n_steps, flat_batch_size = batch.dones.shape + pop_batch_shape = (n_agents, n_steps, flat_batch_size//self.n_eval, self.n_eval) + batch = jax.tree_util.tree_map(lambda x: x.reshape(*pop_batch_shape, *x.shape[3:]), batch) + mean_env_returns_per_agent, max_env_returns_per_agent, _ = \ + jax.vmap(_ued_scores._compute_ued_scores, in_axes=(None, 0))( + score_name, batch + ) + + mean_return = mean_env_returns_per_agent.mean(0) + max_return = max_env_returns_per_agent.max(0) + + assert (mean_return != ued_score).sum() == 0 + + batch_size = self.n_agents*self.n_envs + assert mean_env_returns_per_agent.sum(0).sum(0) == jnp.arange(1,8).mean()*batch_size + assert max_env_returns_per_agent.sum(0).sum(0) == 7*batch_size diff --git a/src/minimax/tests/test_wrappers.py b/src/minimax/tests/test_wrappers.py new file mode 100644 index 0000000..19ee4de --- /dev/null +++ b/src/minimax/tests/test_wrappers.py @@ -0,0 +1,146 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import pytest +import numpy as np +import jax +import jax.numpy as jnp + +import tests.dummy_test_envs as dummy_test_envs +import envs +from envs.environment_ued import UEDEnvironment +from envs.wrappers import EnvWrapper +from envs.wrappers import UEDEnvWrapper + + +class TestEnvWrapper: + def setup_class(self): + # Set up environment + wrapper + env_kwargs = dict( + reward_per_step=1, + max_episode_steps=3 + ) + env_kwargs = env_kwargs + ued_env_kwargs = env_kwargs + self.env_kwargs = env_kwargs + + self.env, self.env_params, self.ued_params = \ + envs.make( + 'DummyRewardEnv', + env_kwargs=env_kwargs, + ued_env_kwargs=ued_env_kwargs, + **self._get_wrappers() + ) + + self.dummy_rng = jax.random.PRNGKey(0) + + @staticmethod + def _get_wrappers(): + return { + 'wrappers':['env_wrapper'], + 'ued_wrappers':['ued_env_wrapper'] + } + + +class TestDefaultEnvWrapper(TestEnvWrapper): + def setup_class(self): + # Set up environment + wrapper + env_kwargs = dict( + reward_per_step=1, + max_episode_steps=2 + ) + env_kwargs = env_kwargs + ued_env_kwargs = env_kwargs + + self.env, self.env_params, self.ued_params = \ + envs.make( + 'DummyRewardEnv', + env_kwargs=env_kwargs, + ued_env_kwargs=ued_env_kwargs, + **self._get_wrappers() + ) + + self.dummy_rng = jax.random.PRNGKey(0) + + def test_base_env(self): + assert isinstance(self.env.base_env, UEDEnvironment) + + def test_reset_extra(self): + extra = self.env.reset_extra() + assert len(extra) == 0 + + def test_step(self): + obs, state, extra = self.env.reset(self.dummy_rng) + extra = self.env.step(self.dummy_rng, state, 0)[-1] + assert len(extra) == 0 + + def test_reset(self): + obs, state, extra = self.env.reset(self.dummy_rng) + assert len(extra) == 0 + + def test_reset_env_instance(self): + _, ued_state = self.env.ued_env.reset(self.dummy_rng) + instance = self.env.ued_env.get_env_instance(self.dummy_rng, ued_state) + extra = self.env.set_env_instance(instance)[-1] + assert len(extra) == 0 + + def reset_teacher(self): + out = self.env.reset_teacher(self.dummy_rng) + assert len(out) == 2 + + def step_teacher(self): + _, ued_state = self.env.reset_teacher(self.dummy_rng) + out = self.env.step_teacher(self.dummy_rng, ued_state, 0) + assert len(out) == 5 + + def reset_student(self): + _, ued_state = self.env.reset_teacher(self.dummy_rng) + _, state, extra = self.env.reset_student(self.dummy_rng, ued_state) + assert len(extra) == 0 + + +class TestMonitorReturnWrapper(TestEnvWrapper): + @staticmethod + def _get_wrappers(): + return { + 'wrappers':['monitor_return'] + } + + def test_wrap_level(self): + assert self.env._wrap_level == 1 + + def test_reset_extra(self): + extra = self.env.reset_extra() + assert len(extra) == 1 and extra['ep_return'] == 0 + + def test_get_monitored_metrics(self): + metrics = self.env.get_monitored_metrics() + assert len(metrics) == 1 and 'return' in metrics + + def test_reset(self): + _, _, extra = self.env.reset(self.dummy_rng) + assert len(extra) == 1 and extra['ep_return'] == 0 + + def test_step(self): + obs, state, extra = self.env.reset(self.dummy_rng) + + n_steps = 2 + return_ = 0 + for i in range(n_steps): + _, state, r, _, _, extra = self.env.step( + self.dummy_rng, state, 0, extra=extra) + return_ += r + + assert extra['ep_return'] == self.env_kwargs['reward_per_step']*n_steps + + # Finish the episode + _, state, r, _, info, extra = self.env.step( + self.dummy_rng, state, 0, extra=extra) + + assert extra['ep_return'] == 0 + assert info['return'] == self.env_kwargs['reward_per_step']*(n_steps+1) \ No newline at end of file diff --git a/src/minimax/train.py b/src/minimax/train.py new file mode 100644 index 0000000..99beadd --- /dev/null +++ b/src/minimax/train.py @@ -0,0 +1,97 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import copy + +# os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=2' +# os.environ['JAX_TRACEBACK_FILTERING'] = 'off' +# os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.70' + +import jax +import wandb + + +from minimax.util.loggers import Logger +from .arguments import parser + + +if __name__ == '__main__': + with jax.disable_jit(False): + args = parser.parse_args(preview=True) + + # === Setup the main runner === + _args = copy.deepcopy(args) # Mutable record of args + if _args.is_multi_agent: + from minimax.runners_ma import ExperimentRunner + else: + from minimax.runners import ExperimentRunner + + xp_runner = ExperimentRunner( + train_runner=_args.train_runner, + env_name=_args.env_name, + agent_rl_algo=_args.agent_rl_algo, + student_model_name=_args.student_model_name, + student_critic_model_name=_args.student_critic_model_name, + student_agent_kind=_args.student_agent_kind, + teacher_model_name=_args.teacher_model_name, + train_runner_kwargs=_args.train_runner_args, + env_kwargs=_args.env_args, + ued_env_kwargs=_args.ued_env_args, + student_rl_kwargs=_args.student_rl_args, + teacher_rl_kwargs=_args.teacher_rl_args, + student_model_kwargs=_args.student_model_args, + teacher_model_kwargs=_args.teacher_model_args, + eval_kwargs=_args.eval_args, + eval_env_kwargs=_args.eval_env_args, + n_devices=_args.n_devices, + shaped_reward_steps=_args.n_shaped_reward_steps, + shaped_reward_n_updates=_args.n_shaped_reward_updates, + xpid=args.xpid + ) + + # === Configure logging === + # Set up wandb + wandb_args = args.wandb_args + if wandb_args.base_url: + os.environ["WANDB_BASE_URL"] = wandb_args.base_url + # if wandb_args.api_key: + # os.environ["WANDB_API_KEY"] = wandb_args.api_key + if wandb_args.base_url: # and wandb_args.api_key: + os.environ["WANDB_CACHE_DIR"] = '~/.cache/wandb' + wandb.init(project=wandb_args.project, + entity=wandb_args.entity, + config=args, + name=args.xpid, + group=wandb_args.group, + mode=wandb_args.mode + ) + callback = wandb.log + else: + callback = None + + logger = Logger( + log_dir=args.log_dir, + xpid=args.xpid, + xp_args=args, + callback=callback, + from_last_checkpoint=args.from_last_checkpoint, + verbose=args.verbose) + + # === Start training === + rng = jax.random.PRNGKey(args.seed) + xp_runner.train( + rng=rng, + n_total_updates=args.n_total_updates, + logger=logger, + log_interval=args.log_interval, + test_interval=args.test_interval, + checkpoint_interval=args.checkpoint_interval, + archive_interval=args.archive_interval, + archive_init_checkpoint=args.archive_init_checkpoint, + from_last_checkpoint=args.from_last_checkpoint) diff --git a/src/minimax/util/__init__.py b/src/minimax/util/__init__.py new file mode 100644 index 0000000..915c09b --- /dev/null +++ b/src/minimax/util/__init__.py @@ -0,0 +1,9 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .dotdict import * \ No newline at end of file diff --git a/src/minimax/util/args.py b/src/minimax/util/args.py new file mode 100644 index 0000000..a4ad4c2 --- /dev/null +++ b/src/minimax/util/args.py @@ -0,0 +1,20 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import argparse + + +def str2bool(v): + if isinstance(v, bool): + return v + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') \ No newline at end of file diff --git a/src/minimax/util/checkpoint.py b/src/minimax/util/checkpoint.py new file mode 100644 index 0000000..717d459 --- /dev/null +++ b/src/minimax/util/checkpoint.py @@ -0,0 +1,74 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import os +import shutil +import pickle +import json +from pathlib import Path + +from .dotdict import DefaultDotDict + + +def save_pkl_object(obj, path): + """Helper to store pickle objects.""" + output_file = Path(path) + output_file.parent.mkdir(exist_ok=True, parents=True) + + with open(path, "wb") as output: + # Overwrites any existing file. + pickle.dump(obj, output, pickle.HIGHEST_PROTOCOL) + + print(f"Stored checkpoint at {path}.") + + +def load_pkl_object(path: str): + """Helper to reload pickle objects.""" + with open(path, "rb") as input: + obj = pickle.load(input) + print(f"Loaded checkpoint from {path}.") + return obj + + +def safe_checkpoint( + state, + dir_path, + name, + index=None, + archive_interval=None): + savename = f'{name}.pkl' + tmp_savepath = f'{name}_tmp.pkl' + + save_path = os.path.join(dir_path, savename) + tmp_savepath = os.path.join(dir_path, tmp_savepath) + + save_pkl_object(state, tmp_savepath) + + # Rename + os.replace(tmp_savepath, save_path) + + # Archive if needed + if index is not None and archive_interval is not None and archive_interval > 0: + if index % archive_interval == 0: + archive_path = os.path.join(dir_path, f'{name}_{index}.pkl') + shutil.copy(save_path, archive_path) + + +def load_config(path: str): + with open(path) as meta_file: + _config = json.load(meta_file)['config'] + + config = {} + for k,v in _config.items(): + if isinstance(v, dict): + v = DefaultDotDict(v) + config[k] = v + + return DefaultDotDict(config) + + diff --git a/src/minimax/util/dotdict.py b/src/minimax/util/dotdict.py new file mode 100644 index 0000000..11cea5c --- /dev/null +++ b/src/minimax/util/dotdict.py @@ -0,0 +1,68 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import copy + + +class DotDict(dict): + __getattr__ = dict.__getitem__ + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + def __init__(self, dct): + for key, value in dct.items(): + if hasattr(value, 'keys'): + value = DotDict(value) + self[key] = value + + def __getstate__(self): + return self + + def __setstate__(self, state): + self.update(state) + self.__dict__ = self + + def __deepcopy__(self, memo): + return DotDict(copy.deepcopy(dict(self))) + + +class DefaultDotDict(dict): + __delattr__ = dict.__delitem__ + + def __init__(self, dct, default=None): + super().__init__(dct) + self._default = default + + def __getstate__(self): + return (self, self._default) + + def __setstate__(self, state): + self.update(state[0]) + self.__dict__ = self + self._default = state[1] + + def __setattr__(self, name, value): + if name == '_default': + super().__setattr__('_default', value) + else: + self.__setitem__(name, value) + + def __deepcopy__(self, memo): + return DefaultDotDict( + copy.deepcopy(dict(self)), + default=self._default + ) + + def __getattr__(self, name): + if name == '_default': + return self._default + else: + try: + return self.__getitem__(name) + except: + return self._default \ No newline at end of file diff --git a/src/minimax/util/graph.py b/src/minimax/util/graph.py new file mode 100644 index 0000000..34fcb35 --- /dev/null +++ b/src/minimax/util/graph.py @@ -0,0 +1,263 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import jax.numpy as jnp +import numpy as np + + +@partial(jax.jit, static_argnums=(1,)) +def apsp(A, n=None): + """ + Compute APSP for adjacency matrix A + using Seidel's algorithm. + """ + if n is None: + n = A.shape[0] + assert(n == A.shape[0]), 'n must equal dim of A.' + + n_steps = int(np.ceil(np.log(n)/np.log(2))) + A_cache = jnp.zeros((n_steps, n, n), dtype=jnp.uint32) + steps_to_reduce = jnp.array(1, dtype=jnp.int32) + + def _scan_fwd_step(carry, step): + i = step + A, A_cache, steps_to_reduce = carry + A_cache = A_cache.at[i].set(A) + + Z = A@A + B = jnp.logical_or( + A == 1, + Z > 0 + ).astype(jnp.uint32) \ + .at[jnp.diag_indices(n)].set(0) + A = B + + complete = B.sum() - jnp.diagonal(B).sum() == n*(n-1) + steps_to_reduce += ~complete + + return (A, A_cache, steps_to_reduce), None + + (B, A_cache, steps_to_reduce), _ = jax.lax.scan( + _scan_fwd_step, + (A, A_cache, 1), + jnp.arange(n_steps), + length=n_steps + ) + + D = 2*B - A_cache[steps_to_reduce-1] + + def _scan_bkwd_step(carry, step): + i = step + (T, A_cache,steps_to_reduce) = carry + + A = A_cache[steps_to_reduce - i - 1] + X = T@A + + thresh = T*(jnp.tile(A.sum(0, keepdims=True), (n, 1))) + D = 2*T*(X >= thresh) + (2*T - 1)*(X < thresh) + T = D*(i < steps_to_reduce) + T*(i >= steps_to_reduce) + + return (T, A_cache, steps_to_reduce), None + + (D, _, _), _ = jax.lax.scan( + _scan_bkwd_step, + (D, A_cache, steps_to_reduce), + jnp.arange(1, n_steps), + length=n_steps-1 + ) + + return D + + +@jax.jit +def grid_to_graph(grid): + """ + Transform a binary grid (True == wall) into a + graph, where walls are all connected to a default node. + """ + h, w = grid.shape + nodes = grid.flatten() + n = len(nodes) + A = jnp.zeros((n,n), dtype=jnp.uint32) + + all_idx = jnp.arange(n) + dum_neighbor_idx = jnp.argmax(~nodes) + dum_neighbor_mask = jnp.zeros(n, dtype=jnp.bool_) + dum_neighbor_mask = \ + dum_neighbor_mask.at[dum_neighbor_idx].set(True) + + def _get_neigbors(idx): + # Return length n boolean mask of neighbors + # We then vmap this function over all n + r = idx + 1 + l = idx - 1 + t = idx - w + b = idx + w + + l_border = jnp.logical_or( + idx % w == 0, + nodes[l] + ) + r_border = jnp.logical_or( + r % w == 0, + nodes[r] + ) + t_border = jnp.logical_or( + idx // w == 0, + nodes[t], + ) + b_border = jnp.logical_or( + idx // w == h - 1, + nodes[b] + ) + + l_ignore = jnp.logical_or( + l_border, + nodes[idx] + ) + r_ignore = jnp.logical_or( + r_border, + nodes[idx] + ) + t_ignore = jnp.logical_or( + t_border, + nodes[idx] + ) + b_ignore = jnp.logical_or( + b_border, + nodes[idx] + ) + + left = l*(1-l_ignore) + idx*(l_ignore) + right = r*(1-r_ignore) + idx*(r_ignore) + top = t*(1-t_ignore) + idx*(t_ignore) + bottom = b*(1-b_ignore) + idx*(b_ignore) + + neighbor_mask = jnp.zeros(n, dtype=jnp.bool_) + neighbor_mask = neighbor_mask.at[jnp.array([left, right, top, bottom])].set(True) + + neighbor_mask = (1-nodes[idx])*neighbor_mask + nodes[idx]*dum_neighbor_mask + + neighbor_mask = neighbor_mask.at[idx].set(False) + + return neighbor_mask + + A = jax.vmap(_get_neigbors)(all_idx).astype(dtype=jnp.uint32) + A = jnp.maximum(A, A.transpose()) + + return A + + +NEIGHBOR_OFFSETS = jnp.array([ + [1,0], # right + [0,1], # bottom + [-1,0], # left + [0,-1], # top + [0,0] # self +], dtype=jnp.int32) + + +@jax.jit +def component_mask_with_pos(grid, pos_a): + """ + Return a mask set to True in all cells that are + a part of the connected component containing pos_a. + """ + h,w = grid.shape + visited_mask = grid + + pos = pos_a + visited_mask = visited_mask.at[ + pos[1],pos[0] + ].set(True) + vstack = jnp.zeros((h*w, 2), dtype=jnp.uint32) + vstack = vstack.at[:2].set(pos) + vstack_size = 2 + + def _scan_dfs(carry, step): + (visited_mask, vstack, vstack_size) = carry + + pos = vstack[vstack_size-1] + + neighbors = \ + jnp.minimum( + jnp.maximum( + pos + NEIGHBOR_OFFSETS, 0 + ), jnp.array([[h,w]]) + ).astype(jnp.uint32) + + neighbors_mask = visited_mask.at[ + neighbors[:,1],neighbors[:,0] + ].get() + n_neighbor_visited = neighbors_mask[:4].sum() + all_visited = n_neighbor_visited == 4 + all_visited_post = n_neighbor_visited >= 3 + neighbors_mask = neighbors_mask.at[-1].set(~all_visited) + + next_neighbor_idx = jnp.argmax(~neighbors_mask) + next_neighbor = neighbors[next_neighbor_idx] + + visited_mask = visited_mask.at[ + next_neighbor[1],next_neighbor[0] + ].set(True) + + vstack_size -= all_visited_post + + vstack = vstack.at[vstack_size].set(next_neighbor) + vstack_size += ~all_visited + + pos = next_neighbor + + return (visited_mask, vstack, vstack_size), None + + max_n_steps = 2*h*w + (visited_mask, vstack, vstack_size), _ = jax.lax.scan( + _scan_dfs, + (visited_mask, vstack, vstack_size), + jnp.arange(max_n_steps), + length=max_n_steps + ) + + visited_mask = visited_mask ^ grid + return visited_mask + + +@jax.jit +def shortest_path_len(grid, pos_a, pos_b,ignore_value=-1): + grid = ~component_mask_with_pos(grid, pos_a) + + A = grid_to_graph(grid) + D = apsp(A, n=A.shape[0]) + + if len(pos_b.shape) == 2: # batch eval + return jax.vmap(_shortest_path_len, in_axes=(None, None, 0, None))( + grid, pos_a, pos_b, D, ignore_value + ) + else: + return _shortest_path_len(grid, pos_a, pos_b, D, ignore_value) + + +@jax.jit +def _shortest_path_len(grid, pos_a, pos_b, D, ignore_value): + h,w = grid.shape + + a_idx = pos_a[1]*w + pos_a[0] + b_idx = pos_b[1]*w + pos_b[0] + d = D[a_idx][b_idx] + + mhttn_d = jnp.sum(jnp.abs(jnp.maximum(pos_a,pos_b)- jnp.minimum(pos_a,pos_b))) + + impossible = jnp.logical_and( + d == 1, + mhttn_d > 1 + ) + + return d*(1-impossible) diff --git a/src/minimax/util/loggers.py b/src/minimax/util/loggers.py new file mode 100644 index 0000000..8f5d646 --- /dev/null +++ b/src/minimax/util/loggers.py @@ -0,0 +1,291 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This file is modified from +https://github.com/openai/baselines + +Licensed under the MIT License; +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://opensource.org/license/mit/ +""" + + +import os +import sys +import json +import csv +import time +import datetime +import copy +import logging +import git +from typing import Dict + +import minimax.util.checkpoint as _checkpoint_util + + +class KVWriter(object): + def writekvs(self, kvs): + raise NotImplementedError + + +class SeqWriter(object): + def writeseq(self, seq): + raise NotImplementedError + + +class HumanOutputFormat(KVWriter, SeqWriter): + def __init__(self, filename_or_file): + if isinstance(filename_or_file, str): + self.file = open(filename_or_file, 'wt') + self.own_file = True + else: + assert hasattr( + filename_or_file, 'read'), 'expected file or str, got %s' % filename_or_file + self.file = filename_or_file + self.own_file = False + + def writekvs(self, kvs): + # Create strings for printing + key2str = {} + for (key, val) in sorted(kvs.items()): + if hasattr(val, '__float__'): + valstr = '%-8.3g' % val + else: + valstr = str(val) + key2str[self._truncate(key)] = self._truncate(valstr) + + # Find max widths + if len(key2str) == 0: + print('WARNING: tried to write empty key-value dict') + return + else: + keywidth = max(map(len, key2str.keys())) + valwidth = max(map(len, key2str.values())) + + # Write out the data + dashes = '-' * (keywidth + valwidth + 7) + lines = [dashes] + for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): + lines.append('| %s%s | %s%s |' % ( + key, + ' ' * (keywidth - len(key)), + val, + ' ' * (valwidth - len(val)), + )) + lines.append(dashes) + self.file.write('\n'.join(lines) + '\n') + + # Flush the output to the file + self.file.flush() + + def _truncate(self, s): + maxlen = 64 + return s[:maxlen-3] + '...' if len(s) > maxlen else s + + def writeseq(self, seq): + seq = list(seq) + for (i, elem) in enumerate(seq): + self.file.write(elem) + if i < len(seq) - 1: # add space unless this is the last one + self.file.write(' ') + self.file.write('\n') + self.file.flush() + + def close(self): + if self.own_file: + self.file.close() + + +def gather_metadata() -> Dict: + date_start = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f") + # Gathering git metadata. + try: + import git + + try: + repo = git.Repo(search_parent_directories=True) + git_sha = repo.commit().hexsha + git_data = dict( + commit=git_sha, + branch=None if repo.head.is_detached else repo.active_branch.name, + is_dirty=repo.is_dirty(), + path=repo.git_dir, + ) + except git.InvalidGitRepositoryError: + git_data = None + except ImportError: + git_data = None + + # Gathering slurm metadata. + if "SLURM_JOB_ID" in os.environ: + slurm_env_keys = [k for k in os.environ if k.startswith("SLURM")] + slurm_data = {} + for k in slurm_env_keys: + d_key = k.replace("SLURM_", "").replace("SLURMD_", "").lower() + slurm_data[d_key] = os.environ[k] + else: + slurm_data = None + return dict( + date_start=date_start, + date_end=None, + successful=False, + git=git_data, + slurm=slurm_data, + env=os.environ.copy(), + ) + + +class Logger: + def __init__( + self, + log_dir='~/logs/minimax', + xpid=None, + xp_args=None, + callback=None, + from_last_checkpoint=False, + verbose=False): + # Set up checkpoint meta + self.verbose = verbose + if self.verbose: + self._stdout = HumanOutputFormat(sys.stdout) + + self._callback = callback + + formatter = logging.Formatter("%(message)s") + self._logger = logging.getLogger("logs/out") + shandle = logging.StreamHandler() + shandle.setFormatter(formatter) + self._logger.addHandler(shandle) + self._logger.setLevel(logging.INFO) + + # Set up main paths for logs and checkpoints + self.paths = {} + log_dir_path = os.path.expandvars(os.path.expanduser(log_dir)) + xpid_dir_path = os.path.join(log_dir_path, xpid) + if not xpid: + xpid = "{proc}_{unixtime}".format( + proc=os.getpid(), unixtime=int(time.time()) + ) + if not os.path.exists(xpid_dir_path): + self._logger.info("Creating log directory: %s", xpid_dir_path) + os.makedirs(xpid_dir_path, exist_ok=True) + self.paths['log_dir'] = log_dir_path + self.paths['xpid_dir'] = xpid_dir_path + self.paths['checkpoint'] = os.path.join( + xpid_dir_path, 'checkpoint.pkl') + + # Create logs.csv file + logs_csv_path = os.path.join(xpid_dir_path, 'logs.csv') + self.paths['log_csv'] = logs_csv_path + self._last_n_logged_lines = 0 + self._last_logged_tick = self._get_last_logged_tick() + + self.append_to_existing_logs = \ + self._last_logged_tick >= 0 \ + and from_last_checkpoint \ + and os.path.exists(self.paths['checkpoint']) + log_mode = "a" if self.append_to_existing_logs else "w+" + self._logfile = open(logs_csv_path, log_mode) + self._logwriter = None + + # Create meta file + if xp_args is not None: + meta_path = os.path.join(xpid_dir_path, 'meta.json') + + meta = gather_metadata() + meta['config'] = dict(xp_args) + meta['xpid'] = xpid + + self._save_metadata(meta_path, meta) + + def _save_metadata(self, meta_path, meta): + with open(meta_path, "w") as jsonfile: + json.dump(meta, jsonfile, indent=4, sort_keys=True) + + def _get_last_logged_tick(self): + last_tick = -1 + logs_csv_path = self.paths['log_csv'] + if os.path.exists(logs_csv_path): + with open(logs_csv_path, "r") as csvfile: + reader = csv.reader(csvfile) + try: + lines = list(reader) + except: + return -1 + # Need at least two lines in order to read the last tick: + # the first is the csv header and the second is the first line + # of data. + if len(lines) > 1: + self._last_n_logged_lines = len(lines) + try: + last_tick = int(lines[-1][0]) + except: + last_tick = -1 + + return last_tick + + def log(self, stats, _tick, ignore_val=None): + if ignore_val is not None: + stats = {k: v if v != ignore_val else None for k, v in stats.items()} + + _stats = { + '_tick': _tick, + '_time': time.time() + } + _stats.update(stats) + stats = _stats + + if self._logwriter is None: + fieldnames = list(stats.keys()) + self._logwriter = csv.DictWriter( + self._logfile, fieldnames=fieldnames) + + if _tick > self._last_logged_tick \ + or not self.append_to_existing_logs: + if self._last_n_logged_lines == 0: + fieldnames = list(stats.keys()) + self._logfile.write("# %s\n" % ",".join(fieldnames)) + self._logfile.flush() + self._last_n_logged_lines = 1 + + self._logwriter.writerow(stats) + self._logfile.flush() + + if self._callback is not None: + self._callback(stats) + + if self.verbose: + self._stdout.writekvs(stats) + + @property + def checkpoint_path(self): + return self.paths['checkpoint'] + + def checkpoint( + self, + runner_state, + name='checkpoint', + index=None, + archive_interval=None): + _checkpoint_util.safe_checkpoint( + runner_state, + self.paths['xpid_dir'], + name, + index, + archive_interval + ) + + def load_last_checkpoint_state(self): + checkpoint_path = \ + os.path.join(self.paths['xpid_dir'], f'checkpoint.pkl') + + if os.path.exists(checkpoint_path): + self._logger.info( + f'Loading previous checkpoint from {checkpoint_path}...') + return _checkpoint_util.load_pkl_object(checkpoint_path) + else: + return None diff --git a/src/minimax/util/parsnip.py b/src/minimax/util/parsnip.py new file mode 100644 index 0000000..8c1f330 --- /dev/null +++ b/src/minimax/util/parsnip.py @@ -0,0 +1,329 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from collections import defaultdict +import argparse +import sys +import re +import pprint + +from minimax.util import DefaultDotDict, DotDict + + +def append_subparser_prefix(prefix, func): + def prefixed_add_argument(*args, **kwargs): + if len(args) > 0: + name = args[0] + if name.startswith('--'): + name = f'--{prefix}_{name[2:]}' + args = (name,) + args[1:] + + return func(*args, **kwargs) + + return prefixed_add_argument + + +def ensure_args_suffix(name): + if not name.endswith('args'): + name = f'{name}_args' + + return name + + +def get_all_cmd_arg_names(): + cmd_args = [s.removeprefix('--') for s in sys.argv if s.startswith('--')] + arg_names = [x.split('=')[0] for x in cmd_args] + + return set(arg_names) + + +def get_argument_kwargs(subparser): + kwargs = [] + skip_list = ['-h', '--help'] + for k, info in subparser.__dict__['_option_string_actions'].items(): + if k in skip_list: + continue + + info_dict = info.__dict__ + kwargs_dict = dict( + # option_strings=info_dict['option_strings'], + name=info_dict['dest'], + const=info_dict['const'], + default=info_dict['default'], + type=info_dict['type'], + choices=info_dict['choices'], + required=info_dict['required'] + ) + + nargs = info_dict['nargs'] + if nargs == '?' or (nargs is not None and int(nargs) > 0): + kwargs_dict.update(dict(nargs=info_dict['nargs'])) + + kwargs.append(kwargs_dict) + + return kwargs + + +class Parsnip: + """ + Wraps a collection of argparse instances + to enable convenient grouping of arguments and + access via a DotDict-style interface. + """ + def __init__(self, description=None): + self._base_parser = \ + argparse.ArgumentParser(description=description) + self._subparsers = {} + self._prefixes = [] + self._dependencies = defaultdict() + self._dests = defaultdict() + self._dependent_args = set() + + def add_subparser( + self, + name, + prefix=None, + dest=None, + depends_on=None, + dependency=None, + is_individual_arg=False, + description=None + ): + if not is_individual_arg: + name = ensure_args_suffix(name) + + assert name not in self._subparsers, \ + f'Multiple subparsers named {name} detected.' + + if dependency is not None: + if depends_on is not None: + depends_on = ensure_args_suffix(depends_on) + + assert depends_on in self._subparsers, \ + f'Missing subparse {depends_on} must be added before dependent {name}.' + + assert isinstance(dependency, dict), \ + f'Subparser dependencies must be specified as dicts.' + + self._dependencies[name] = (depends_on, dependency) + + if dest is not None: + dest = ensure_args_suffix(dest) + assert dest in self._subparsers, \ + f"Missing dest {dest} must be specified before source {name}." + + subparser = argparse.ArgumentParser( + description=description, + allow_abbrev=False) + if prefix is not None: + subparser.add_argument = append_subparser_prefix( + prefix, subparser.add_argument, + ) + + self._subparsers[name] = subparser + self._prefixes.append(prefix) + self._dests[name] = dest + + return subparser + + def add_dependent_argument( + self, + *args, + **kwargs,): + + assert 'dependency' in kwargs, \ + 'Must specify dependency in kwargs.' + dependency = kwargs.pop('dependency') + + name = args[0].removeprefix('--') + + prefix = kwargs.pop('prefix', None) + dest = kwargs.pop('dest', None) + + subparser = self.add_subparser( + name, + prefix=prefix, + dependency=dependency, + dest=dest, + is_individual_arg=True, + description=kwargs.pop('description', '') + ) + subparser.add_argument(*args, **kwargs) + + self._dependent_args.add(name) + + def copy_arguments( + self, + src, + dest, + arg_prefix=None + ): + all_subparsers = [k for k in self._subparsers.keys()] + src_name = f'{src}_args' + src_subparser = self._subparsers[src_name] + src_idx = all_subparsers.index(src_name) + src_prefix = self._prefixes[src_idx] + + dest_name = f'{dest}_args' + dest_subparser = self._subparsers[dest_name] + dest_idx = all_subparsers.index(dest_name) + + arg_prefix = f'{arg_prefix}_' if arg_prefix else '' + + argument_kwargs = get_argument_kwargs(src_subparser) + for kwargs in argument_kwargs: + name = kwargs.pop('name') + flag = f"--{arg_prefix}{name.removeprefix(f'{src_prefix}_')}" + + dest_subparser.add_argument( + flag, + **kwargs + ) + + def parse_args(self, preview=False): + cmd_arg_names, args, arg_data, _ = self._parse_cmd_line_flags() + + for name in cmd_arg_names: + if name not in arg_data: + raise ValueError(f'Unknown argument {name}.') + + if preview: + print('πŸ₯• Parsnip harvested the following arguments:') + pp = pprint.PrettyPrinter(indent=4) + pp.pprint(args) + + return args + + def parse_cmd_line_flags(self, as_grid_json=False): + _, _, arg_data, _ = self._parse_cmd_line_flags() + + if as_grid_json: + arg_data = {k:[v] for k,v in arg_data.items() if v is not None} + + return arg_data + + def _parse_cmd_line_flags(self, keep_structure=False): + cmd_arg_names = get_all_cmd_arg_names() + + arg_data = {} + argname2keypath = {} + + args = DefaultDotDict(vars(self._base_parser.parse_known_args()[0])) + for k in args: + arg_data[k] = args[k] + argname2keypath[k] = [k] + + for i, (name, subparser) in enumerate(self._subparsers.items()): + if name in self._dependencies: + depends_on, dependencies = self._dependencies[name] + if depends_on is None: + dsubargs = arg_data + else: + dsubargs = arg_data[depends_on] + skip_subparser = False + for dk, dv in dependencies.items(): + if not isinstance(dv, (tuple, list)): + dv = [dv] + + for v in dv: + if isinstance(v, str): + match_regex = f"^{v.replace('*', '.*')}$" + match = re.match(match_regex, dsubargs[dk]) is not None + else: + match = dsubargs[dk] == v + if match: + break + + if not match: + skip_subparser = True + break + + if skip_subparser: + continue + + prefix = self._prefixes[i] + subargs, _ = subparser.parse_known_args() + + subargs = vars(subargs) + for k in subargs: + arg_data[k] = subargs[k] + + subargs = DotDict( + {k.removeprefix(f'{prefix}_'):v for k,v in subargs.items()} + ) + + # Check for flatten and merge conditions + dest = self._dests.get(name) + if name in self._dependent_args: + if dest is None: + args[name] = subargs[name] + else: + args[dest].update({name: subargs[name]}) + else: + if dest is None: + args[name] = subargs + else: + args[dest].update(subargs) + + for k in subargs: + if prefix is not None: + argname = f'{prefix}_{k}' + else: + argname = k + if dest is None: + sp_name = name + else: + sp_name = dest + if sp_name == k: + argname2keypath[argname] = [k] + else: + argname2keypath[argname] = [sp_name, k] + + return cmd_arg_names, args, arg_data, argname2keypath + + @property + def argname2keypath(self): + args = DefaultDotDict(vars(self._base_parser.parse_known_args()[0])) + argname2keypath = {} + for k in args: + argname2keypath[k] = args[k] + + for i, (name, subparser) in enumerate(self._subparsers.items()): + if name in self._dependencies: + depends_on, dependencies = self._dependencies[name] + if depends_on is None: + dsubargs = arg_data + else: + dsubargs = arg_data[depends_on] + skip_subparser = False + for dk, dv in dependencies.items(): + if not isinstance(dv, (tuple, list)): + dv = [dv] + + for v in dv: + if isinstance(v, str): + match_regex = f"^{v.replace('*', '.*')}$" + match = re.match(match_regex, dsubargs[dk]) is not None + else: + match = dsubargs[dk] == v + if match: + break + + if not match: + skip_subparser = True + break + + if skip_subparser: + continue + + prefix = self._prefixes[i] + subargs, _ = subparser.parse_known_args() + + def __getattr__(self, attr): + # Default missing attr to _base_argparse instance + return getattr(self._base_parser, attr) \ No newline at end of file diff --git a/src/minimax/util/pytree.py b/src/minimax/util/pytree.py new file mode 100644 index 0000000..f25f572 --- /dev/null +++ b/src/minimax/util/pytree.py @@ -0,0 +1,37 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +import collections + +import jax + + +def pytree_set_array_at(pytree, i, value): + return jax.tree_util.tree_map(lambda x,y: x.at[i].set(y), pytree, value) + +def pytree_set_struct_at(pytree, i, value): + return jax.tree_util.tree_map(lambda x,y: x.at[i].set(y), pytree, value) + +def pytree_at(pytree, start, end=None): + return jax.tree_util.tree_map(lambda x: x.at[start:end].get(), pytree) + +def pytree_select(pred, on_true, on_false): + vselect = jax.vmap(jax.lax.select, in_axes=(0, 0, 0)) + return jax.tree_util.tree_map(lambda x,y: vselect(pred, x, y), on_true, on_false) + +def pytree_expand_batch_dim(pytree, batch_shape, n_batch_axes=2): + """ + Expands a single batch dimension into a multi-dim batch shape + """ + return jax.tree_util.tree_map(lambda x: x.reshape(*batch_shape, *x.shape[n_batch_axes:]), pytree) + +def pytree_transform(pytree, transform): + return jax.tree_util.tree_map(lambda x: transform(x), pytree) + +def pytree_merge(dst, src, start_idx, src_len): + return jax.tree_map(lambda x,y: x.at[start_idx:start_idx+src_len].set(y.at[:src_len].get()), dst, src) \ No newline at end of file diff --git a/src/minimax/util/rl/__init__.py b/src/minimax/util/rl/__init__.py new file mode 100644 index 0000000..11b8e73 --- /dev/null +++ b/src/minimax/util/rl/__init__.py @@ -0,0 +1,16 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from .training import VmapTrainState, VmapMAPPOTrainState +from .agent_pop import AgentPop +from .agent_pop_heterogenous import AgentPopHeterogenous +from .rolling_stats import RollingStats +from .rollout_storage import RolloutStorage +from .rollout_storage_seperate import RolloutStorageSeperate +from .ued_scores import * +from .plr import PLRManager, PopPLRManager diff --git a/src/minimax/util/rl/agent_pop.py b/src/minimax/util/rl/agent_pop.py new file mode 100644 index 0000000..0a8ee1a --- /dev/null +++ b/src/minimax/util/rl/agent_pop.py @@ -0,0 +1,119 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import numpy as np +import jax +import jax.numpy as jnp + + +class AgentPop: + """ + Manages multiple agents of the same architecture + """ + + def __init__( + self, + agent, + n_agents): + """ + Maintains a set of model parameters. + """ + self.agent = agent + self.n_agents = n_agents + + def _reshape_to_pop(self, x): + return jax.tree_map( + lambda x: jnp.reshape(x, newshape=(self.n_agents, x.shape[0]//self.n_agents, *x.shape[1:])), x) + + def _flatten(self, x): + return jax.tree_map(lambda x: jnp.reshape(x, newshape=(self.n_agents*x.shape[1], -1)).squeeze(), x) + + def init_params(self, rng, obs): + if self.agent.is_recurrent: + # Make time first dim + obs = jax.tree_map(lambda x: x[jnp.newaxis, :], obs) + + vrngs = jax.random.split(rng, self.n_agents) + return jax.vmap( + self.agent.init_params, + in_axes=(0, None) + )(vrngs, obs) + + @partial(jax.jit, static_argnums=(0,)) + def init_carry(self, rng, obs): + if hasattr(self.agent, "actor") and self.agent.actor.conv_encoder: + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-3] + elif not hasattr(self.agent, "actor"): + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-3] + else: # Linear obs + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-1] + return self.agent.init_carry(rng=rng, batch_dims=agent_batch_dim) + + @partial(jax.jit, static_argnums=(0,)) + def act(self, params, obs, carry, reset=None): + # If recurrent, add time axis to support scanned rollouts + if self.agent.is_recurrent: + # Add time dim after agent dim + obs = jax.tree_map(lambda x: x[:, jnp.newaxis, :], obs) + + if reset is None: + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[2] + reset = jnp.zeros( + (self.n_agents, 1, agent_batch_dim), dtype=jnp.bool_) + else: + reset = reset[:, jnp.newaxis, :] + + value, pi_params, next_carry = jax.vmap( + self.agent.act)(params, obs, carry, reset) + + if self.agent.is_recurrent: # Remove time dim + if value is not None: + value = value.squeeze(1) + pi_params = jax.tree_map(lambda x: x.squeeze(1), pi_params) + + return value, pi_params, next_carry + + def get_action_dist(self, dist_params, dtype=jnp.uint8): + return self.agent.get_action_dist(dist_params, dtype=dtype) + + @partial(jax.jit, static_argnums=(0,)) + def get_value(self, params, obs, carry, reset=None): + if self.agent.is_recurrent: + # Add time dim after agent dim + obs = jax.tree_map(lambda x: x[:, jnp.newaxis, :], obs) + + if reset is None: + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[2] + reset = jnp.zeros( + (self.n_agents, 1, agent_batch_dim), dtype=jnp.bool_) + else: + reset = reset[:, jnp.newaxis, :] + + value, next_carry = jax.vmap( + self.agent.get_value)(params, obs, carry, reset) + + if self.agent.is_recurrent: # Remove time dim + value = value.squeeze(1) + + if value.shape[-1] == 1: + value = value.squeeze(-1) + return value, next_carry + + @partial(jax.jit, static_argnums=(0, 4, 5)) + def update(self, rng, train_state, batch, prefix_steps=0, fake=False): + if fake: + return train_state, jax.vmap(lambda *_: self.agent.get_empty_update_stats())(np.arange(self.n_agents)) + + rng, *vrngs = jax.random.split(rng, self.n_agents+1) + vrngs = jnp.array(vrngs) + + new_train_state, stats = jax.vmap( + self.agent.update)(vrngs, train_state, batch) + return new_train_state, stats diff --git a/src/minimax/util/rl/agent_pop_heterogenous.py b/src/minimax/util/rl/agent_pop_heterogenous.py new file mode 100644 index 0000000..39381ca --- /dev/null +++ b/src/minimax/util/rl/agent_pop_heterogenous.py @@ -0,0 +1,166 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import numpy as np +import jax +import jax.numpy as jnp + + +class AgentPopHeterogenous: + """ + Manages multiple agents with no assumption regarding the architecture + """ + + def __init__( + self, + agent_0, + agent_1, + n_agents): + """ + Maintains a set of model parameters. + """ + self.agent_0 = agent_0 + self.agent_1 = agent_1 + self.n_agents = n_agents + + def _reshape_to_pop(self, x): + return jax.tree_map( + lambda x: jnp.reshape(x, newshape=(self.n_agents, x.shape[0]//self.n_agents, *x.shape[1:])), x) + + def _flatten(self, x): + return jax.tree_map(lambda x: jnp.reshape(x, newshape=(self.n_agents*x.shape[1], -1)).squeeze(), x) + + def init_params(self, rng, obs): + if self.agent.is_recurrent: + # Make time first dim + obs = jax.tree_map(lambda x: x[jnp.newaxis, :], obs) + + vrngs = jax.random.split(rng, self.n_agents) + return jax.vmap( + self.agent.init_params, + in_axes=(0, None) + )(vrngs, obs) + + # @partial(jax.jit, static_argnums=(0,)) + # def init_carry(self, rng, obs): + # if self.agent.actor.conv_encoder: + # agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-3] + # else: # Linear obs + # agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-1] + # return self.agent.init_carry(rng=rng, batch_dims=agent_batch_dim) + + @partial(jax.jit, static_argnums=(0,)) + def init_carry_agent_0(self, rng, obs): + if self.agent_0.actor.conv_encoder: + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-3] + else: # Linear obs + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-1] + return self.agent_0.init_carry(rng=rng, batch_dims=agent_batch_dim) + + @partial(jax.jit, static_argnums=(0,)) + def init_carry_agent_1(self, rng, obs): + if self.agent_1.actor.conv_encoder: + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-3] + else: # Linear obs + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[:-1] + return self.agent_1.init_carry(rng=rng, batch_dims=agent_batch_dim) + + @partial(jax.jit, static_argnums=(0,)) + def act(self, params, obs, carry, reset=None): + # If recurrent, add time axis to support scanned rollouts + actor_0_params, actor_1_params = params + actor_0_carry, actor_1_carry = carry + + if self.agent_0.is_recurrent: + # Add time dim after agent dim + obs_0 = jax.tree_map( + lambda x: x[:, jnp.newaxis, :], obs['agent_0']) + + if reset is None: + agent_batch_dim = jax.tree_util.tree_leaves(obs_0)[0].shape[2] + reset = jnp.zeros( + (self.n_agents, 1, agent_batch_dim), dtype=jnp.bool_) + else: + reset = reset[:, jnp.newaxis, :] + else: + obs_0 = obs['agent_0'] + + if self.agent_1.is_recurrent: + # Add time dim after agent dim + obs_1 = jax.tree_map( + lambda x: x[:, jnp.newaxis, :], obs['agent_1']) + + if reset is None: + agent_batch_dim = jax.tree_util.tree_leaves(obs_1)[0].shape[2] + reset = jnp.zeros( + (self.n_agents, 1, agent_batch_dim), dtype=jnp.bool_) + else: + reset = reset[:, jnp.newaxis, :] + else: + obs_1 = obs['agent_1'] + + value_0, pi_params_0, next_0_carry = jax.vmap( + self.agent_0.act)(actor_0_params, obs_0, actor_0_carry, reset) + + value_1, pi_param_1, next_1_carry = jax.vmap( + self.agent_1.act)(actor_1_params, obs_1, actor_1_carry, reset) + + if self.agent_0.is_recurrent: # Remove time dim + if value_0 is not None: + value_0 = value_0.squeeze(1) + pi_params_0 = jax.tree_map(lambda x: x.squeeze(1), pi_params_0) + + if self.agent_1.is_recurrent: # Remove time dim + if value_1 is not None: + value_1 = value_1.squeeze(1) + pi_param_1 = jax.tree_map(lambda x: x.squeeze(1), pi_param_1) + + return value_0, value_1, pi_params_0, pi_param_1, next_0_carry, next_1_carry + + def get_action_0_dist(self, dist_params, dtype=jnp.uint8): + return self.agent_0.get_action_dist(dist_params, dtype=dtype) + + def get_action_1_dist(self, dist_params, dtype=jnp.uint8): + return self.agent_1.get_action_dist(dist_params, dtype=dtype) + + @partial(jax.jit, static_argnums=(0,)) + def get_value(self, params, obs, carry, reset=None): + if self.agent.is_recurrent: + # Add time dim after agent dim + obs = jax.tree_map(lambda x: x[:, jnp.newaxis, :], obs) + + if reset is None: + agent_batch_dim = jax.tree_util.tree_leaves(obs)[0].shape[2] + reset = jnp.zeros( + (self.n_agents, 1, agent_batch_dim), dtype=jnp.bool_) + else: + reset = reset[:, jnp.newaxis, :] + + value, next_carry = jax.vmap( + self.agent.get_value)(params, obs, carry, reset) + + if self.agent.is_recurrent: # Remove time dim + value = value.squeeze(1) + + if value.shape[-1] == 1: + value = value.squeeze(-1) + return value, next_carry + + @partial(jax.jit, static_argnums=(0, 4, 5)) + def update(self, rng, train_state, batch, prefix_steps=0, fake=False): + if fake: + return train_state, jax.vmap(lambda *_: self.agent.get_empty_update_stats())(np.arange(self.n_agents)) + + rng, *vrngs = jax.random.split(rng, self.n_agents+1) + vrngs = jnp.array(vrngs) + + new_train_state, stats = jax.vmap( + self.agent.update)(vrngs, train_state, batch) + return new_train_state, stats diff --git a/src/minimax/util/rl/hl_gauss_transform.py b/src/minimax/util/rl/hl_gauss_transform.py new file mode 100644 index 0000000..a44f57c --- /dev/null +++ b/src/minimax/util/rl/hl_gauss_transform.py @@ -0,0 +1,39 @@ +import chex +import jax +import jax.numpy as jnp +import jax.scipy.special + + +def hl_gauss_transform( + min_value: float, + max_value: float, + num_bins: int, + sigma: float, +): + support = jnp.linspace(min_value, max_value, num_bins+1, dtype=jnp.float32) + + def transform_to_probs(target: chex.Array) -> chex.Array: + cdf_evals = jax.scipy.special.erf((support-target)/(jnp.sqrt(2)*sigma)) + z = cdf_evals[-1] - cdf_evals[0] + bin_probs = cdf_evals[1:] - cdf_evals[:-1] + return bin_probs / z + + def transform_from_probs(probs: chex.Array) -> chex.Array: + centers = (support[:-1] + support[1:]) / 2 + return jnp.sum(probs * centers) + + return transform_to_probs, transform_from_probs + + +if __name__ == '__main__': + transform_to_probs, transform_from_probs = hl_gauss_transform( + min_value=0, + max_value=20.0, + num_bins=10, + sigma=0.1, + ) + + for r in [0, 3, 20]: + probs = transform_to_probs(jnp.array(r)) + print(f'Probs for {r}: {probs}') + print(f'Reconstructed from probs: {transform_from_probs(probs)}') diff --git a/src/minimax/util/rl/plr.py b/src/minimax/util/rl/plr.py new file mode 100644 index 0000000..ffbd9fc --- /dev/null +++ b/src/minimax/util/rl/plr.py @@ -0,0 +1,466 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import jax.numpy as jnp +from flax import struct +import chex +import numpy as np + +from .ued_scores import UEDScore + + +class PLRBuffer(struct.PyTreeNode): + levels: chex.Array + scores: chex.Array + ages: chex.Array + max_returns: chex.Array # for MaxMC + filled: chex.Array + filled_count: chex.Array + n_mutations: chex.Array + + ued_score: int = struct.field( + pytree_node=False, default=UEDScore.L1_VALUE_LOSS.value) + replay_prob: float = struct.field(pytree_node=False, default=0.5) + buffer_size: int = struct.field(pytree_node=False, default=100) + staleness_coef: float = struct.field(pytree_node=False, default=0.3) + temp: float = struct.field(pytree_node=False, default=1.0) + use_score_ranks: bool = struct.field(pytree_node=False, default=True) + min_fill_ratio: float = struct.field(pytree_node=False, default=0.5) + use_robust_plr: bool = struct.field(pytree_node=False, default=False) + use_parallel_eval: bool = struct.field(pytree_node=False, default=False) + + +class PLRManager: + def __init__( + self, + example_level, # Example env instance + ued_score, + replay_prob=0.5, + buffer_size=100, + staleness_coef=0.3, + temp=1.0, + min_fill_ratio=0.5, + use_score_ranks=True, + use_robust_plr=False, + use_parallel_eval=False, + comparator_fn=None, + n_devices=1): + + assert not (ued_score == UEDScore.MAX_MC and not use_score_ranks), \ + 'Cannot use proportional normalization with MaxMC, which can produce negative scores.' + + self.ued_score = ued_score + self.replay_prob = replay_prob + self.buffer_size = buffer_size + self.staleness_coef = staleness_coef + self.temp = temp + self.min_fill_ratio = min_fill_ratio + self.use_score_ranks = use_score_ranks + self.use_robust_plr = use_robust_plr + self.use_parallel_eval = use_parallel_eval + self.comparator_fn = comparator_fn + + self.n_devices = n_devices + + example_level = jax.tree_map(lambda x: jnp.array(x), example_level) + self.levels = jax.tree_map( + lambda x: ( + jnp.tile(jnp.zeros_like(x), (buffer_size,) + (1,)*(len(x.shape)-1))).reshape(buffer_size, *x.shape), + example_level) + + self.scores = jnp.full(buffer_size, -jnp.inf) + self.max_returns = jnp.full(buffer_size, -jnp.inf) + self.ages = jnp.zeros(buffer_size, dtype=jnp.uint32) + self.filled = jnp.zeros(buffer_size, dtype=jnp.bool_) + self.filled_count = jnp.zeros((1,), dtype=jnp.int32) + self.n_mutations = jnp.zeros(buffer_size, dtype=jnp.uint32) + + @partial(jax.jit, static_argnums=(0,)) + def reset(self): + return PLRBuffer( + ued_score=self.ued_score.value, + replay_prob=self.replay_prob, + buffer_size=self.buffer_size, + staleness_coef=self.staleness_coef, + temp=self.temp, + min_fill_ratio=self.min_fill_ratio, + use_robust_plr=self.use_robust_plr, + use_parallel_eval=self.use_parallel_eval, + levels=self.levels, + scores=self.scores, + max_returns=self.max_returns, + ages=self.ages, + filled=self.filled, + filled_count=self.filled_count, + n_mutations=self.n_mutations) + + partial(jax.jit, static_argnums=(0,)) + + def _get_replay_dist(self, scores, ages, filled): + # Score dist + if self.use_score_ranks: + sorted_idx = jnp.argsort(-scores) # Top first + scores = jnp.zeros(self.buffer_size, dtype=jnp.int32)\ + .at[sorted_idx]\ + .set(1/jnp.arange(self.buffer_size)) + + scores = scores*filled + # Scores in this implementation might contain NaN + # scores = jnp.nan_to_num( + # scores, nan=0.0, posinf=+jnp.inf, neginf=-jnp.inf) + score_dist = scores/self.temp + z = score_dist.sum() + z = jnp.where(jnp.equal(z, 0), 1, z) + score_dist = jax.lax.select( + jnp.greater(z, 0), + score_dist/z, + filled*1. # Assign equal weight to all present levels + ) + + # Staleness dist + staleness_scores = ages*filled + # staleness_scores = jnp.nan_to_num( + # staleness_scores, nan=0.0, posinf=+jnp.inf, neginf=-jnp.inf) + _z = staleness_scores.sum() + z = jnp.where(jnp.equal(_z, 0), 1, _z) + staleness_dist = jax.lax.select( + jnp.greater(_z, 0), + staleness_scores/z, + score_dist # If no solutions are stale, do not sample from staleness dist + ) + + # Replay dist + replay_dist = (1-self.staleness_coef)*score_dist \ + + self.staleness_coef*staleness_dist + + return replay_dist + + partial(jax.jit, static_argnums=(0,)) + + def _get_next_insert_idx(self, plr_buffer): + return jax.lax.cond( + jnp.greater(plr_buffer.buffer_size, plr_buffer.filled_count[0]), + lambda *_: plr_buffer.filled_count[0], + lambda *_: jnp.argmin(self._get_replay_dist(plr_buffer.scores, + plr_buffer.ages, plr_buffer.filled)) + ) + + @partial(jax.jit, static_argnums=(0, 3)) + def _sample_replay_levels(self, rng, plr_buffer, n): + def _sample_replay_level(carry, step): + ages = carry + subrng = step + replay_dist = self._get_replay_dist( + plr_buffer.scores, ages, plr_buffer.filled) + replay_idx = jax.random.choice(subrng, np.arange( + self.buffer_size), shape=(), p=replay_dist) + replay_level = jax.tree_map(lambda x: x.take( + replay_idx, axis=0), plr_buffer.levels) + + ages = ((ages + 1)*(plr_buffer.filled)).at[replay_idx].set(0) + + return ages, (replay_level, replay_idx) + + rng, *subrngs = jax.random.split(rng, n+1) + next_ages, (replay_levels, replay_idxs) = jax.lax.scan( + _sample_replay_level, + plr_buffer.ages, + jnp.array(subrngs) + ) + + next_plr_buffer = plr_buffer.replace( + ages=next_ages + ) + + return replay_levels, replay_idxs, next_plr_buffer + + def _sample_buffer_uniform(self, rng, plr_buffer, n): + rand_idxs = jax.random.choice(rng, np.arange( + self.buffer_size), shape=(n,), p=plr_buffer.filled) + levels = jax.tree_map(lambda x: x.take( + replay_idx, axis=0), plr_buffer.levels) + + return levels, rand_idxs, plr_buffer + + # Levels must be sampled sequentially, to account for staleness + @partial(jax.jit, static_argnums=(0, 4, 5)) + def sample(self, rng, plr_buffer, new_levels, n, random=False): + rng, replay_rng, sample_rng = jax.random.split(rng, 3) + + is_replay = jnp.greater( + self.replay_prob, jax.random.uniform(replay_rng)) + is_warm = jnp.greater_equal( + plr_buffer.filled.sum()/self.buffer_size, self.min_fill_ratio) + + if self.n_devices > 1: # Synchronize replay + is_replay = jax.lax.all_gather(is_replay, axis_name='device')[0] + is_warm = jnp.all(jax.lax.all_gather(is_warm, axis_name='device')) + + is_replay = jnp.logical_and(is_replay, is_warm) + + if random: + sample_fn = self._sample_buffer_uniform + else: + sample_fn = self._sample_replay_levels + + levels, level_idxs, next_plr_buffer = jax.lax.cond( + is_replay, + partial(sample_fn, n=n), + lambda *_: (new_levels, np.full(n, -1), plr_buffer), + *(sample_rng, plr_buffer) + ) + + # Update ages when not sampling replay + next_plr_buffer = jax.lax.cond( + is_replay, + lambda *_: next_plr_buffer, + lambda *_: next_plr_buffer.replace( + ages=(plr_buffer.ages+n)*(plr_buffer.filled)) + ) + + return levels, level_idxs, is_replay, next_plr_buffer + + @partial(jax.jit, static_argnums=(0,)) + def dedupe_levels(self, plr_buffer, levels, level_idxs): + if self.comparator_fn is not None and level_idxs.shape[-1] > 2: + def _check_equal(carry, step): + match_idxs, other_levels, is_self = carry + batch_idx, level = step + + matches = jax.vmap(self.comparator_fn, in_axes=( + 0, None))(other_levels, level) + + top2match, top2match_idxs = jax.lax.top_k(matches, 2) + + is_self_dupe = jnp.logical_and( + is_self, top2match[1]) # More than 1 match + is_dedupe_idx = jnp.logical_and( + is_self_dupe, jnp.greater(batch_idx, top2match_idxs[0])) + self_match_idx = top2match_idxs[0] * \ + is_dedupe_idx - (~is_dedupe_idx) + + _match_idx = jnp.where( + is_self, + self_match_idx, # only first + top2match_idxs[0], # use first matching index in buffer + ) + + match_idxs = jnp.where( + matches.any(), + match_idxs.at[batch_idx].set(_match_idx), + match_idxs + ) + + return (match_idxs, other_levels, is_self), None + + # dedupe among batch levels + batch_dupe_idxs = jnp.full_like(level_idxs, -1) + (batch_dupe_idxs, _, _), _ = jax.lax.scan( + _check_equal, + (batch_dupe_idxs, levels, True), + (np.arange(level_idxs.shape[-1]), levels) + ) + batch_dupe_mask = jnp.greater(batch_dupe_idxs, -1) + + # dedupe against PLR buffer levels + (level_idxs, _, _), _ = jax.lax.scan( + _check_equal, + (level_idxs, plr_buffer.levels, False), + (np.arange(level_idxs.shape[-1]), levels) + ) + + return level_idxs, batch_dupe_mask + else: + return level_idxs, jnp.zeros_like(level_idxs, dtype=jnp.bool_) + + @partial(jax.jit, static_argnums=(0, 7)) + def update(self, plr_buffer, levels, level_idxs, ued_scores, dupe_mask=None, info=None, ignore_val=-jnp.inf, parent_idxs=None): + # Note: parent_idxs are only used for mutated levels + done_masks = (ued_scores != ignore_val) + if dupe_mask is not None: + # Ignore duplicate levels in batch by treating them as not done + done_masks = jnp.logical_and(done_masks, ~dupe_mask) + + cur_n_mutations = plr_buffer.n_mutations + insert_mask = jnp.zeros((self.buffer_size,), dtype=jnp.bool_) + + def update_level_info(carry, step): + plr_buffer, insert_mask = carry + levels = plr_buffer.levels + scores = plr_buffer.scores + filled = plr_buffer.filled + + score, level, level_idx, done_mask, parent_idx, max_return = step + + next_insert_idx = self._get_next_insert_idx(plr_buffer) + is_new_level = jnp.greater(0, level_idx) + insert_idx = jnp.where( + is_new_level, + next_insert_idx, # new level + level_idx, + ) + + should_insert = jnp.greater_equal( + score, scores.at[insert_idx].get()) + should_insert = jnp.logical_and(should_insert, done_mask) + + is_existing_level = jnp.logical_and(~is_new_level, done_mask) + should_update = jnp.logical_and( + is_existing_level, ~insert_mask.at[insert_idx].get()) + should_insert = jnp.logical_and(should_insert, ~should_update) + next_insert_mask = jnp.where( + should_insert, + insert_mask.at[insert_idx].set(True), + insert_mask + ) + should_insert_or_update = jnp.logical_or( + should_insert, should_update) + + # Update max return if needed + next_max_returns = jnp.where( + should_insert_or_update, + plr_buffer.max_returns.at[insert_idx].set(max_return), + plr_buffer.max_returns + ) + + updated_level = jax.tree_map( + lambda x, y: jax.lax.select(should_insert, x, y), + level, + jax.tree_map(lambda x: x.at[insert_idx].get(), levels) + ) + next_levels = jax.tree_map( + lambda x, y: x.at[insert_idx].set(y), levels, updated_level) + + next_scores = jnp.where( + should_insert_or_update, + scores.at[insert_idx].set(score), + scores + ) + next_filled = jnp.where( + should_insert, + filled.at[insert_idx].set(True), + filled + ) + + plr_replace_kwargs = dict( + levels=next_levels, + scores=next_scores, + filled=next_filled, + filled_count=jnp.array([next_filled.sum()]), + max_returns=next_max_returns + ) + + # Update mutation count + n_mutations = plr_buffer.n_mutations + should_incr_n_mutations = jnp.logical_and( + jnp.not_equal(parent_idx, -1), should_insert) + should_reset_n_mutations = jnp.logical_and( + jnp.equal(parent_idx, -1), should_insert_or_update) + reset_n_mutations = jnp.where( + is_existing_level, + cur_n_mutations.at[insert_idx].get(), + 0 + ) + next_n_mutations = jnp.where( + should_incr_n_mutations, + n_mutations.at[insert_idx].set( + cur_n_mutations.at[parent_idx].get() + 1), + n_mutations + ) + next_n_mutations = jnp.where( + should_reset_n_mutations, + n_mutations.at[insert_idx].set(reset_n_mutations), + next_n_mutations + ) + + plr_replace_kwargs['n_mutations'] = next_n_mutations + + next_plr_buffer = plr_buffer.replace(**plr_replace_kwargs) + + return (next_plr_buffer, next_insert_mask), None + + if parent_idxs is None: + parent_idxs = jnp.full_like(level_idxs, -1) + + if plr_buffer.ued_score == UEDScore.MAX_MC.value: + max_returns = info['max_returns'] + else: + max_returns = jnp.full_like(level_idxs, -1) + carry = (ued_scores, levels, level_idxs, + done_masks, parent_idxs, max_returns) + + (next_plr_buffer, _), _ = jax.lax.scan( + update_level_info, + (plr_buffer, insert_mask), + carry + ) + + return next_plr_buffer + + @partial(jax.jit, static_argnums=(0,)) + def get_metrics(self, plr_buffer): + replay_dist = self._get_replay_dist( + plr_buffer.scores, + plr_buffer.ages, + plr_buffer.filled) + weighted_n_mutations = (plr_buffer.n_mutations*replay_dist).sum() + scores = jnp.where(plr_buffer.filled, plr_buffer.scores, 0) + weighted_ued_score = (scores*replay_dist).sum() + + weighted_age = (plr_buffer.ages*replay_dist).sum() + + return dict( + weighted_n_mutations=weighted_n_mutations, + weighted_ued_score=weighted_ued_score, + weighted_age=weighted_age + ) + + +class PopPLRManager(PLRManager): + def __init__(self, *, n_agents, **kwargs): + super().__init__(**kwargs) + + self.n_agents = n_agents + + @partial(jax.jit, static_argnums=(0, 1)) + def reset(self, n): + sup = super() + return jax.vmap(lambda *_: sup.reset())(np.arange(n)) + + partial(jax.jit, static_argnums=(0, 4, 5)) + + def sample(self, rng, plr_buffer, new_levels, n, random=False): + sup = super() + + rng, *vrngs = jax.random.split(rng, self.n_agents+1) + + return jax.vmap(sup.sample, in_axes=(0, 0, 0, None, None))( + jnp.array(vrngs), plr_buffer, new_levels, n, random + ) + + @partial(jax.jit, static_argnums=(0,)) + def dedupe_levels(self, plr_buffer, levels, level_idxs): + sup = super() + return jax.vmap(sup.dedupe_levels)(plr_buffer, levels, level_idxs) + + partial(jax.jit, static_argnums=(0, 7)) + + def update(self, plr_buffer, levels, level_idxs, ued_scores, dupe_mask=None, info=None, ignore_val=-jnp.inf, parent_idxs=None): + sup = super() + return jax.vmap(sup.update, in_axes=(0, 0, 0, 0, 0, 0, None, 0))( + plr_buffer, levels, level_idxs, ued_scores, dupe_mask, info, ignore_val, parent_idxs + ) + + partial(jax.jit, static_argnums=(0,)) + + def get_metrics(self, plr_buffer): + sup = super() + return jax.vmap(sup.get_metrics)(plr_buffer) diff --git a/src/minimax/util/rl/rolling_stats.py b/src/minimax/util/rl/rolling_stats.py new file mode 100644 index 0000000..7d992fd --- /dev/null +++ b/src/minimax/util/rl/rolling_stats.py @@ -0,0 +1,116 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial + +import jax +import jax.numpy as jnp + + +class RollingStats: + """ + This class tracks episodic stats, such as final returns + and env complexity metrics. Works on a per-env basis. + """ + + def __init__(self, names, step_metrics_names=[], window=None): + self.names = names + self.step_metric_names = step_metrics_names + self.window = window + + @partial(jax.jit, static_argnums=(0, 1)) + def reset_stats(self, batch_shape=(1,)): + stats = { + 'n_episodes': jnp.zeros((*(batch_shape), 1), dtype=jnp.uint32), + 'n_steps': jnp.zeros((*(batch_shape), 1), dtype=jnp.uint32), + 'reward': jnp.zeros((*(batch_shape), 1), dtype=jnp.float32), + } + stats.update({ + name: jnp.zeros((*(batch_shape), 1)) for name in self.names + }) + + if self.window is not None: + # Average over window + stats.update({ + f'{name}_buffer': jnp.zeros((*(batch_shape), self.window)) + for name in self.names + }) + + return stats + + @partial(jax.jit, static_argnums=(0,)) + def update_stats(self, stats, done, info, max_episodes=jnp.inf): + n_eps = stats['n_episodes'] + n_steps = stats['n_steps'] + + + for name in self.names: + # Update stat + if name not in info: + continue + + new_val = info[name] + + # NOTE: in MA settings sparse and dense rewards are per agent. + new_val = new_val.sum() + + # Only record first max_episode episodes + done = done*(n_eps < max_episodes) + if name in self.step_metric_names: + n_incr_prev = n_steps + n_incr_total = n_steps + 1 + _metric_done = True + else: + n_incr_prev = n_eps + n_incr_total = n_eps + done + _metric_done = done + + if self.window is None: + mean = stats[name] + new_mean = self._update_stat_mean( + new_val, mean, n_incr_total, _metric_done + ) + else: + buffer_key = f'{name}_buffer' + buffer = stats[buffer_key] + new_mean, buffer = self._update_stat_window( + new_val, buffer, n_incr_total, n_incr_prev, _metric_done + ) + stats.update({buffer_key: buffer}) + + stats.update({ + name: new_mean, + }) + + # Only update n_episodes based on real episodes + if name in self.step_metric_names: + stats.update({ + 'n_steps': n_incr_total + }) + else: + stats.update({ + 'n_episodes': n_incr_total + }) + + return stats + + @partial(jax.jit, static_argnums=(0,)) + def _update_stat_mean(self, new_val, mean, n_eps_total, done): + z = 1/jnp.maximum(1, n_eps_total) + new_mean = done*(mean*(1-z) + new_val*z) + (1-done)*mean + + return new_mean + + @partial(jax.jit, static_argnums=(0,)) + def _update_stat_window(self, new_val, buffer, n_eps_total, n_eps_prev, done): + cur_val = buffer[n_eps_prev % self.window] + new_val = done*new_val + (1-done)*cur_val + buffer = buffer.at[n_eps_prev % self.window].set(new_val) + new_mean = buffer.sum()/jnp.maximum(jnp.minimum(self.window, n_eps_total), 1) + + return new_mean, buffer diff --git a/src/minimax/util/rl/rollout_storage.py b/src/minimax/util/rl/rollout_storage.py new file mode 100644 index 0000000..59a032a --- /dev/null +++ b/src/minimax/util/rl/rollout_storage.py @@ -0,0 +1,227 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from collections import namedtuple + +import numpy as np +import jax +import jax.numpy as jnp + +import minimax.util.pytree as _tree_util +from .ued_scores import compute_episodic_stats + + +RolloutBatch = namedtuple( + 'RolloutBatch', ( + 'obs', + 'actions', + 'rewards', + 'dones', + 'log_pis', + 'values', + 'targets', + 'advantages', + 'carry' + )) + + +class RolloutStorage: + def __init__( + self, + discount, + gae_lambda, + n_envs, + n_eval, + n_steps, + action_space, + obs_space, + agent, + n_agents=1): + self.discount = discount + self.gae_lambda = gae_lambda + self.n_agents = n_agents + self.n_steps = n_steps + self.n_envs = n_envs + self.n_evals = n_eval + self.flat_batch_size = n_envs*self.n_evals + self.action_space = action_space + self.value_ensemble_size = agent.model.value_ensemble_size + + dummy_rng = jax.random.PRNGKey(0) + self.empty_obs = \ + jax.jax.tree_util.tree_map( + lambda x: jnp.empty((n_agents, n_steps, self.flat_batch_size) + + x.shape, dtype=x.dtype), + obs_space.sample(dummy_rng)) + + self.empty_action = \ + jax.jax.tree_util.tree_map( + lambda x: jnp.empty((n_agents, n_steps, self.flat_batch_size) + + x.shape, dtype=x.dtype), + action_space.sample(dummy_rng)) + + if agent.is_recurrent: + self.empty_carry = \ + agent.init_carry( + dummy_rng, batch_dims=(n_agents, self.n_steps, self.flat_batch_size)) + else: + self.empty_carry = None + + if agent.is_recurrent: + self.append = jax.vmap(self._append_with_carry, in_axes=0) + else: + self.append = jax.vmap(self._append_without_carry, in_axes=0) + self.get_batch = jax.vmap(self._get_batch) + self.get_return_stats = jax.vmap( + self._get_return_stats, in_axes=(0, None)) + + @partial(jax.jit, static_argnums=0) + def reset(self): + """ + Maintains a pytree of rollout transitions and metadata + """ + if self.empty_carry is None: + carry_buffer = None + else: + carry_buffer = self.empty_carry + + value_batch_size = (self.flat_batch_size,) + if self.value_ensemble_size > 1: + value_batch_size += (self.value_ensemble_size,) + + return { + "obs": self.empty_obs, + "actions": self.empty_action, + "rewards": jnp.empty( + (self.n_agents, self.n_steps, + self.flat_batch_size), dtype=jnp.float32 + ), + "dones": jnp.empty((self.n_agents, self.n_steps, self.flat_batch_size), dtype=jnp.uint8), + "log_pis_old": jnp.empty( + (self.n_agents, self.n_steps, + self.flat_batch_size), dtype=jnp.float32 + ), + "values_old": jnp.empty( + (self.n_agents, self.n_steps, * + value_batch_size), dtype=jnp.float32 + ), + "carry": carry_buffer, + "_t": jnp.zeros((self.n_agents,), dtype=jnp.uint32) # for vmap + } + + @partial(jax.jit, static_argnums=0) + def _append(self, buffer, obs, action, reward, done, log_pi, value, carry): + if carry is not None: + carry_buffer = _tree_util.pytree_set_array_at( + buffer["carry"], buffer["_t"], carry) + else: + carry_buffer = None + + return { + "obs": _tree_util.pytree_set_struct_at(buffer["obs"], buffer["_t"], obs), + "actions": _tree_util.pytree_set_struct_at(buffer["actions"], buffer["_t"], action), + "rewards": buffer["rewards"].at[buffer["_t"]].set(reward.squeeze()), + "dones": buffer["dones"].at[buffer["_t"]].set(done.squeeze()), + "log_pis_old": buffer["log_pis_old"].at[buffer["_t"]].set(log_pi), + "values_old": buffer["values_old"].at[buffer["_t"]].set(value.squeeze()), + "carry": carry_buffer, + "_t": (buffer["_t"] + 1) % self.n_steps, + } + + @partial(jax.jit, static_argnums=0) + def _append_with_carry(self, buffer, obs, action, reward, done, log_pi, value, carry): + return self._append(buffer, obs, action, reward, done, log_pi, value, carry) + + @partial(jax.jit, static_argnums=0) + def _append_without_carry(self, buffer, obs, action, reward, done, log_pi, value): + return self._append(buffer, obs, action, reward, done, log_pi, value, None) + + @partial(jax.jit, static_argnums=(0,)) + def _get_batch(self, buffer, last_value): + _dones = buffer["dones"] + rewards = buffer["rewards"] + + gae, target = self.compute_gae( + value=buffer["values_old"], + reward=rewards, + done=_dones, + last_value=last_value + ) + + # T x N x E x M --> N x T x EM if recurrent or N x TEM if not + if self.empty_carry is not None: + carry = buffer["carry"] + else: + carry = None + + batch_kwargs = dict( + obs=buffer["obs"], + actions=buffer["actions"], + rewards=rewards, + dones=_dones, + log_pis=buffer["log_pis_old"], + values=buffer["values_old"], + targets=target, + advantages=gae, + carry=carry, + ) + + return RolloutBatch(**batch_kwargs) + + def compute_gae(self, value, reward, done, last_value): + def _compute_gae(carry, step): + (discount, gae_lambda, gae, value_next) = carry + value, reward, done = step + + value_diff = discount*value_next*(1-done) - value + delta = reward + value_diff + + gae = delta + discount*gae_lambda*(1-done) * gae + + return (discount, gae_lambda, gae, value), gae + + value, reward, done = jnp.flip(value, 0), jnp.flip( + reward, 0), jnp.flip(done, 0) + + # Handle ensemble values, which have an extra ensemble dim at index -1 + if value.shape != done.shape: + reward = jnp.expand_dims(reward, -1) + done = jnp.expand_dims(done, -1) + + gae = jnp.zeros(value.shape[1:]) + _, advantages = jax.lax.scan( + _compute_gae, + (self.discount, self.gae_lambda, gae, last_value), + (value, reward, done), + length=len(reward) + ) + advantages = jnp.flip(advantages, 0) + targets = advantages + jnp.flip(value, 0) + + return advantages, targets + + def _get_return_stats(self, rollout, control_idxs=None): + if control_idxs is not None: + positive_signs = (control_idxs == 0) + reward_signs = -1*(positive_signs.astype(jnp.float32) - + (~positive_signs).astype(jnp.float32)) + rewards = rollout["rewards"]*reward_signs + else: + rewards = rollout["rewards"] + + pop_batch_shape = (self.n_steps, self.n_envs, self.n_evals) + rewards = jnp.flip(rewards.reshape(*pop_batch_shape), 0) + dones = jnp.flip(rollout["dones"].reshape(*pop_batch_shape), 0) + + return compute_episodic_stats(rewards, dones) + + def set_final_reward(self, rollout, reward): + rollout["rewards"] = rollout["rewards"].at[:, -1, :].set(reward) + + return rollout diff --git a/src/minimax/util/rl/rollout_storage_seperate.py b/src/minimax/util/rl/rollout_storage_seperate.py new file mode 100644 index 0000000..51f4d52 --- /dev/null +++ b/src/minimax/util/rl/rollout_storage_seperate.py @@ -0,0 +1,274 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from collections import namedtuple + +import jax +import jax.numpy as jnp + +import minimax.util.pytree as _tree_util +from .ued_scores import compute_episodic_stats + + +RolloutBatch = namedtuple( + 'RolloutBatch', ( + 'obs', + 'obs_shared', + 'actions', + 'rewards', + 'dones', + 'log_pis', + 'values', + 'targets', + 'advantages', + 'actor_carry', + 'critic_carry' + )) + + +class RolloutStorageSeperate: + def __init__( + self, + discount, + gae_lambda, + n_envs, + n_eval, + n_steps, + action_space, + obs_space, + obs_space_shared_shape, + agent, + n_agents=1): + self.discount = discount + self.gae_lambda = gae_lambda + + # NOTE: n_students refers to minimax's use of n_agents + # Since I added a multi agent env I need an actual n_agents + self.n_students = n_agents + self.n_env_agents = 2 + self.n_steps = n_steps + self.n_envs = n_envs + self.n_evals = n_eval + self.flat_batch_size = n_envs*self.n_evals + self.action_space = action_space + + dummy_rng = jax.random.PRNGKey(0) + self.empty_obs = \ + jax.jax.tree_util.tree_map( + lambda x: jnp.empty( + (self.n_students, n_steps, self.flat_batch_size, + self.n_env_agents) + x.shape, + dtype=x.dtype + ), + obs_space.sample(dummy_rng)) + + self.empty_obs_shared = jnp.empty( + (self.n_students, n_steps, self.flat_batch_size, + self.n_env_agents) + obs_space_shared_shape) + + self.empty_action = \ + jax.jax.tree_util.tree_map( + lambda x: jnp.empty( + (self.n_students, n_steps, self.flat_batch_size, self.n_env_agents) + x.shape, dtype=x.dtype), + action_space.sample(dummy_rng)) + + if agent.is_recurrent: + self.empty_actor_carry, self.empty_critic_carry = \ + agent.init_carry( + dummy_rng, batch_dims=( + self.n_students, n_steps, self.flat_batch_size, self.n_env_agents)) + else: + self.empty_actor_carry, self.empty_critic_carry = None, None + + if agent.is_recurrent: + self.append = jax.vmap(self._append_with_carry, in_axes=0) + else: + self.append = jax.vmap(self._append_without_carry, in_axes=0) + self.get_batch = jax.vmap(self._get_batch) + self.get_return_stats = jax.vmap( + self._get_return_stats, in_axes=(0, None)) + + @partial(jax.jit, static_argnums=0) + def reset(self): + """ + Maintains a pytree of rollout transitions and metadata + """ + if self.empty_actor_carry is None and self.empty_critic_carry is None: + actor_carry_buffer = None + critic_carry_buffer = None + else: + actor_carry_buffer = self.empty_actor_carry + critic_carry_buffer = self.empty_critic_carry + + value_batch_size = (self.flat_batch_size,) + + return { + "obs": self.empty_obs, + "obs_shared": self.empty_obs_shared, + "actions": self.empty_action, + "rewards": jnp.empty( + (self.n_students, self.n_steps, self.flat_batch_size, + self.n_env_agents), dtype=jnp.float32 + ), + "shaped_rewards": jnp.empty( + (self.n_students, self.n_steps, self.flat_batch_size, + self.n_env_agents), dtype=jnp.float32 + ), + "dones": jnp.empty((self.n_students, self.n_steps, self.flat_batch_size, + self.n_env_agents), dtype=jnp.uint8), + "log_pis_old": jnp.empty( + (self.n_students, self.n_steps, self.flat_batch_size, + self.n_env_agents), dtype=jnp.float32 + ), + "values_old": jnp.empty( + (self.n_students, self.n_steps, + *value_batch_size, self.n_env_agents), dtype=jnp.float32 + ), + "actor_carry": actor_carry_buffer, + "critic_carry": critic_carry_buffer, + "_t": jnp.zeros((self.n_students,), dtype=jnp.uint32) # for vmap + } + + @partial(jax.jit, static_argnums=0) + def _append(self, buffer, obs, obs_shared, action, reward, shaped_reward, done, log_pi, value, actor_carry, critic_carry): + if actor_carry is not None: + actor_carry_buffer = _tree_util.pytree_set_array_at( + buffer["actor_carry"], buffer["_t"], actor_carry) + else: + actor_carry_buffer = None + + if critic_carry is not None: + critic_carry_buffer = _tree_util.pytree_set_array_at( + buffer["critic_carry"], buffer["_t"], critic_carry) + else: + critic_carry_buffer = None + + obs = _tree_util.pytree_set_struct_at(buffer["obs"], buffer["_t"], obs) + obs_shared = _tree_util.pytree_set_struct_at( + buffer["obs_shared"], buffer["_t"], obs_shared) + + return { + "obs": obs, + "obs_shared": obs_shared, + "actions": _tree_util.pytree_set_struct_at(buffer["actions"], buffer["_t"], action), + "rewards": buffer["rewards"].at[buffer["_t"]].set(reward.squeeze()), + "shaped_rewards": buffer["shaped_rewards"].at[buffer["_t"]].set(shaped_reward.squeeze()), + "dones": buffer["dones"].at[buffer["_t"]].set(done.squeeze()), + "log_pis_old": buffer["log_pis_old"].at[buffer["_t"]].set(log_pi), + "values_old": buffer["values_old"].at[buffer["_t"]].set(value.squeeze()), + "actor_carry": actor_carry_buffer, + "critic_carry": critic_carry_buffer, + "_t": (buffer["_t"] + 1) % self.n_steps, + } + + @partial(jax.jit, static_argnums=0) + def _append_with_carry(self, buffer, obs, obs_shared, action, reward, shaped_reward, done, log_pi, value, actor_carry, critic_carry): + return self._append(buffer, obs, obs_shared, action, reward, shaped_reward, done, log_pi, value, actor_carry, critic_carry) + + @partial(jax.jit, static_argnums=0) + def _append_without_carry(self, buffer, obs, obs_shared, action, reward, shaped_reward, done, log_pi, value): + return self._append(buffer, obs, obs_shared, action, reward, shaped_reward, done, log_pi, value, None, None) + + @partial(jax.jit, static_argnums=(0,)) + # , intrinsic_reward_coeff=0.0): + def _get_batch(self, buffer, last_value, shaped_reward_coeff=None): + _dones = buffer["dones"] + rewards = buffer["rewards"] + + # if intrinsic_reward is not None: + # rewards = rewards + 0.0001 * intrinsic_reward_coeff * intrinsic_reward + # 0.0001 * + jax.debug.print("rewards buffer {r}", r=rewards.mean()) + + rewards = rewards + shaped_reward_coeff.mean() * \ + buffer["shaped_rewards"] + + jax.debug.print("rewards buffer {r}", r=rewards.mean()) + + gae, target = self.compute_gae( + value=buffer["values_old"], + reward=rewards, + done=_dones, + last_value=last_value + ) + + # T x N x E x M --> N x T x EM if recurrent or N x TEM if not + if self.empty_actor_carry is not None and self.empty_critic_carry is not None: + actor_carry = buffer["actor_carry"] + critic_carry = buffer["critic_carry"] + else: + actor_carry = None + critic_carry = None + + batch_kwargs = dict( + obs=buffer["obs"], + obs_shared=buffer["obs_shared"], + actions=buffer["actions"], + rewards=rewards, + dones=_dones, + log_pis=buffer["log_pis_old"], + values=buffer["values_old"], + targets=target, + advantages=gae, + actor_carry=actor_carry, + critic_carry=critic_carry + ) + return RolloutBatch(**batch_kwargs) + + def compute_gae(self, value, reward, done, last_value): + def _compute_gae(carry, step): + (discount, gae_lambda, gae, value_next) = carry + value, reward, done = step + + value_diff = discount*value_next*(1-done) - value + delta = reward + value_diff + + gae = delta + discount*gae_lambda*(1-done) * gae + + return (discount, gae_lambda, gae, value), gae + + value, reward, done = jnp.flip(value, 0), jnp.flip( + reward, 0), jnp.flip(done, 0) + + # Handle ensemble values, which have an extra ensemble dim at index -1 + if value.shape != done.shape: + reward = jnp.expand_dims(reward, -1) + done = jnp.expand_dims(done, -1) + + gae = jnp.zeros(value.shape[1:]) + _, advantages = jax.lax.scan( + _compute_gae, + (self.discount, self.gae_lambda, gae, last_value), + (value, reward, done), + length=len(reward) + ) + advantages = jnp.flip(advantages, 0) + targets = advantages + jnp.flip(value, 0) + + return advantages, targets + + def _get_return_stats(self, rollout, control_idxs=None): + if control_idxs is not None: + positive_signs = (control_idxs == 0) + reward_signs = -1*(positive_signs.astype(jnp.float32) - + (~positive_signs).astype(jnp.float32)) + rewards = rollout["rewards"]*reward_signs + else: + rewards = rollout["rewards"] + + pop_batch_shape = (self.n_steps, self.n_envs, self.n_evals) + rewards = jnp.flip(rewards.reshape(*pop_batch_shape), 0) + dones = jnp.flip(rollout["dones"].reshape(*pop_batch_shape), 0) + + return compute_episodic_stats(rewards, dones) + + def set_final_reward(self, rollout, reward): + rollout["rewards"] = rollout["rewards"].at[:, -1, :].set(reward) + + return rollout diff --git a/src/minimax/util/rl/training.py b/src/minimax/util/rl/training.py new file mode 100644 index 0000000..4e8c2db --- /dev/null +++ b/src/minimax/util/rl/training.py @@ -0,0 +1,201 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from typing import Any, Callable + +import jax +import jax.numpy as jnp +from flax import core +from flax import struct +import optax +import chex + +from .plr import PLRBuffer + + +class VmapMAPPOTrainState(struct.PyTreeNode): + n_iters: chex.Array + n_updates: chex.Array # per agent + n_grad_updates: chex.Array # per agent + actor_apply_fn: Callable = struct.field(pytree_node=False) + actor_params: core.FrozenDict[str, Any] + actor_tx: optax.GradientTransformation = struct.field(pytree_node=False) + actor_opt_state: optax.OptState + + critic_apply_fn: Callable = struct.field(pytree_node=False) + critic_params: core.FrozenDict[str, Any] + critic_tx: optax.GradientTransformation = struct.field(pytree_node=False) + critic_opt_state: optax.OptState + + shaped_reward_coeff: float = 0.0 + + plr_buffer: PLRBuffer = None + + def apply_gradients(self, *, actor_grads, critic_grads, **kwargs): + # Actor update + actor_updates, new_actor_opt_state = self.actor_tx.update( + actor_grads, self.actor_opt_state, self.actor_params) + new_actor_params = optax.apply_updates( + self.actor_params, actor_updates) + + # Critic update + critic_updates, new_critic_opt_state = self.critic_tx.update( + critic_grads, self.critic_opt_state, self.critic_params) + new_critic_params = optax.apply_updates( + self.critic_params, critic_updates) + + return self.replace( + n_grad_updates=self.n_updates + 1, + actor_params=new_actor_params, + actor_opt_state=new_actor_opt_state, + critic_params=new_critic_params, + critic_opt_state=new_critic_opt_state, + **kwargs, + ) + + @classmethod + def create( + cls, *, + actor_apply_fn, + actor_params, + actor_tx, + critic_apply_fn, + critic_params, + critic_tx, + **kwargs + ): + actor_opt_state = jax.vmap(actor_tx.init)(actor_params) + critic_opt_state = jax.vmap(critic_tx.init)(critic_params) + return cls( + n_iters=jnp.array(jax.vmap(lambda x: 0)( + actor_params), dtype=jnp.uint32), + n_updates=jnp.array(jax.vmap(lambda x: 0) + (actor_params), dtype=jnp.uint32), + n_grad_updates=jnp.array( + jax.vmap(lambda x: 0)(actor_params), dtype=jnp.uint32), + actor_apply_fn=actor_apply_fn, + actor_params=actor_params, + actor_tx=actor_tx, + actor_opt_state=actor_opt_state, + critic_apply_fn=critic_apply_fn, + critic_params=critic_params, + critic_tx=critic_tx, + critic_opt_state=critic_opt_state, + **kwargs, + ) + + def increment(self): + return self.replace( + n_iters=self.n_iters + 1, + ) + + def increment_updates(self): + return self.replace( + n_updates=self.n_updates + 1, + ) + + @property + def state_dict(self): + return dict( + n_iters=self.n_iters, + n_updates=self.n_updates, + n_grad_updates=self.n_grad_updates, + actor_params=self.actor_params, + actor_opt_state=self.actor_opt_state, + critic_params=self.critic_params, + critic_opt_state=self.critic_opt_state, + ) + + def set_new_shaped_reward_coeff(self, new_coeff): + return self.replace( + shaped_reward_coeff=new_coeff + ) + + def load_state_dict(self, state): + return self.replace( + n_iters=state['n_iters'], + n_updates=state['n_updates'], + n_grad_updates=state['n_grad_updates'], + actor_params=state['actor_params'], + actor_opt_state=state['actor_opt_state'], + critic_params=state['critic_params'], + critic_opt_state=state['critic_opt_state'], + ) + + +class VmapTrainState(struct.PyTreeNode): + n_iters: chex.Array + n_updates: chex.Array # per agent + n_grad_updates: chex.Array # per agent + apply_fn: Callable = struct.field(pytree_node=False) + params: core.FrozenDict[str, Any] + tx: optax.GradientTransformation = struct.field(pytree_node=False) + opt_state: optax.OptState + plr_buffer: PLRBuffer = None + + def apply_gradients(self, *, grads, **kwargs): + updates, new_opt_state = self.tx.update( + grads, self.opt_state, self.params) + new_params = optax.apply_updates(self.params, updates) + + return self.replace( + n_grad_updates=self.n_updates + 1, + params=new_params, + opt_state=new_opt_state, + **kwargs, + ) + + @classmethod + def create(cls, *, + apply_fn, + params, + tx, + **kwargs + ): + opt_state = jax.vmap(tx.init)(params) + return cls( + n_iters=jnp.array(jax.vmap(lambda x: 0)(params), dtype=jnp.uint32), + n_updates=jnp.array(jax.vmap(lambda x: 0) + (params), dtype=jnp.uint32), + n_grad_updates=jnp.array( + jax.vmap(lambda x: 0)(params), dtype=jnp.uint32), + apply_fn=apply_fn, + params=params, + tx=tx, + opt_state=opt_state, + **kwargs, + ) + + def increment(self): + return self.replace( + n_iters=self.n_iters + 1, + ) + + def increment_updates(self): + return self.replace( + n_updates=self.n_updates + 1, + ) + + @property + def state_dict(self): + return dict( + n_iters=self.n_iters, + n_updates=self.n_updates, + n_grad_updates=self.n_grad_updates, + params=self.params, + opt_state=self.opt_state + ) + + def load_state_dict(self, state): + return self.replace( + n_iters=state['n_iters'], + n_updates=state['n_updates'], + n_grad_updates=state['n_grad_updates'], + params=state['params'], + opt_state=state['opt_state'] + ) diff --git a/src/minimax/util/rl/ued_scores.py b/src/minimax/util/rl/ued_scores.py new file mode 100644 index 0000000..7492c48 --- /dev/null +++ b/src/minimax/util/rl/ued_scores.py @@ -0,0 +1,245 @@ +""" +Copyright (c) Meta Platforms, Inc. and affiliates. +All rights reserved. + +This source code is licensed under the license found in the +LICENSE file in the root directory of this source tree. +""" + +from functools import partial +from enum import Enum +from collections import namedtuple + +import einops +import jax +import jax.numpy as jnp + + +class UEDScore(Enum): + RELATIVE_REGRET = 1 + MEAN_RELATIVE_REGRET = 2 + POPULATION_REGRET = 3 + RETURN = 4 + NEG_RETURN = 5 + L1_VALUE_LOSS = 6 + POSITIVE_VALUE_LOSS = 7 + MAX_MC = 8 + VALUE_DISAGREEMENT = 9 + + +@partial(jax.jit, static_argnums=(2, 3)) +def compute_episodic_stats( + metrics, + dones, + time_average=False, + partial_metrics=0, + partial_steps=0, + return_partial=False): + env_batch_shape = dones.shape[1:] + n_episodes = jnp.zeros(env_batch_shape, dtype=jnp.uint32) + sum_ep_metrics = jnp.zeros(env_batch_shape, dtype=jnp.float32) + partial_metrics = jnp.zeros(env_batch_shape, dtype=jnp.float32) + max_metrics = jnp.zeros(env_batch_shape, dtype=jnp.float32) + steps = jnp.zeros(env_batch_shape, dtype=jnp.float32) + partial_steps = jnp.zeros(env_batch_shape, dtype=jnp.float32) + + def _compute_metrics(carry, step): + (n_episodes, + sum_ep_metrics, + max_metrics, + partial_metrics, + partial_steps) = carry + + _metrics, _dones = step + + partial_metrics += _metrics + partial_steps += 1 + + if time_average: + ep_metric = partial_metrics/partial_steps + else: + ep_metric = partial_metrics + + sum_ep_metrics += _dones*ep_metric + max_metrics = _dones * \ + jnp.maximum(max_metrics, ep_metric) + (1-_dones)*max_metrics + + n_episodes += _dones + + partial_metrics = (1-_dones)*partial_metrics + partial_steps = (1-_dones)*partial_steps + + return ( + n_episodes, + sum_ep_metrics, + max_metrics, + partial_metrics, + partial_steps + ), None + + (n_episodes, sum_ep_metrics, max_metrics, partial_metrics, partial_steps), _ = jax.lax.scan( + _compute_metrics, + (n_episodes, sum_ep_metrics, max_metrics, partial_metrics, partial_steps), + (metrics, dones), + length=len(metrics) + ) + + """Score per level based on two agents.""" + if len(n_episodes.shape) == 3: # n_parallel envs x n_parallel eval_envs x n_env_agents + n_episodes = n_episodes.sum(-1) + max_metrics = max_metrics.max(-1) + sum_ep_metrics = sum_ep_metrics.sum(-1) + + # Take mean over eval dimension + total_metrics_per_env = sum_ep_metrics.sum(-1) + n_episodes_per_env = n_episodes.sum(-1) + n_episodes_per_env = jnp.maximum(n_episodes_per_env, 1) + + # Take max over eval dimension + max_metrics_per_env = max_metrics.max(-1) + + return total_metrics_per_env/n_episodes_per_env, max_metrics_per_env + + +@partial(jax.jit, static_argnums=(0,)) +def _compute_ued_scores(score_name: UEDScore, batch: namedtuple, info=None): + """ + Compute UED score from a rollout batch. + Individual score functions return a tuple of mean_scores and max_scores, + where each is of dimension n_agents x n_envs. + """ + if score_name in [UEDScore.RELATIVE_REGRET, UEDScore.MEAN_RELATIVE_REGRET, UEDScore.POPULATION_REGRET]: + mean_scores, max_scores, score_info = compute_return(batch) + + elif score_name == UEDScore.RETURN: + mean_scores, max_scores, score_info = compute_return(batch) + + elif score_name == UEDScore.NEG_RETURN: + batch = batch._replace(rewards=-batch.rewards) + mean_scores, max_scores, score_info = compute_return(batch) + + elif score_name == UEDScore.MAX_MC: + mean_scores, max_scores, score_info = compute_max_mc(batch, info) + + elif score_name == UEDScore.L1_VALUE_LOSS: + mean_scores, max_scores, score_info = compute_l1_value_loss(batch) + + elif score_name == UEDScore.POSITIVE_VALUE_LOSS: + mean_scores, max_scores, score_info = compute_positive_value_loss( + batch) + + elif score_name == UEDScore.VALUE_DISAGREEMENT: + mean_scores, max_scores, score_info = compute_value_disagreement(batch) + + return mean_scores, max_scores, score_info + + +@partial(jax.jit, static_argnums=(0, 2, 4, 5)) +def compute_ued_scores(score_name: UEDScore, batch: namedtuple, n_eval: int, info: dict = None, ignore_val=None, per_agent=False): + if len(batch.dones.shape) == 3: + n_agents, n_steps, flat_batch_size = batch.dones.shape + else: + n_agents, n_steps, flat_batch_size, _ = batch.dones.shape + # pop_batch_shape = (n_agents, n_steps, flat_batch_size//n_eval, n_eval) + # batch = jax.tree_util.tree_map(lambda x: x.reshape( + # *pop_batch_shape, *x.shape[3:]), batch) + + batch = jax.tree_util.tree_map( + lambda x: einops.rearrange( + x, 'a t (s e) ... -> a t s e ...', + a=n_agents, t=n_steps, s=flat_batch_size, e=n_eval), batch) + + mean_env_returns_per_agent, max_env_returns_per_agent, score_info = \ + jax.vmap(_compute_ued_scores, in_axes=(None, 0, 0))( + score_name, batch, info + ) + + if score_name in [UEDScore.RELATIVE_REGRET, UEDScore.MEAN_RELATIVE_REGRET]: + assert len(mean_env_returns_per_agent) == 2, \ + "Standard PAIRED requires exactly 2 agents." + + if score_name == UEDScore.RELATIVE_REGRET: + scores = jnp.clip(max_env_returns_per_agent[1] + - mean_env_returns_per_agent[0], 0) + + elif score_name == UEDScore.MEAN_RELATIVE_REGRET: + scores = jnp.clip(mean_env_returns_per_agent[1] + - mean_env_returns_per_agent[0], 0) + + elif score_name == UEDScore.POPULATION_REGRET: + max_env_returns = max_env_returns_per_agent.max(0) + mean_env_returns = mean_env_returns_per_agent.mean(0) + scores = max_env_returns - mean_env_returns + else: + if per_agent: + scores = mean_env_returns_per_agent + max_scores = max_env_returns_per_agent + else: + scores = mean_env_returns_per_agent.mean(0) + max_scores = max_env_returns_per_agent.max(0) + + if ignore_val is not None: + if per_agent: + axis = (1, -1) if len(batch.dones.shape) == 3 else (1, -2, -1) + else: + axis = (0, 1, -1) if len(batch.dones.shape) == 3 else (0, 1, -2, -1) + + incomplete_idxs = batch.dones.sum(axis=axis) == 0 + + scores = jnp.where(incomplete_idxs, ignore_val, scores) + return scores, score_info + +# ======== UED score computations ======== + + +def compute_return(batch): + mean_scores, max_scores = compute_episodic_stats( + batch.rewards, batch.dones, time_average=False) + + return mean_scores, max_scores, None + + +def compute_l1_value_loss(batch): + mean_scores, max_scores = compute_episodic_stats( + jnp.abs(batch.advantages), batch.dones, time_average=True) + + return mean_scores, max_scores, None + + +def compute_positive_value_loss(batch): + mean_scores, max_scores = compute_episodic_stats( + jnp.clip(batch.advantages, 0), batch.dones, time_average=True) + + return mean_scores, max_scores, None + + +def compute_max_mc(batch, info): + _, max_env_returns_per_agent = \ + compute_episodic_stats(batch.rewards, batch.dones, time_average=False) + + max_returns = jnp.maximum(max_env_returns_per_agent, info['max_returns']) + # Multi Agent setting, we have mutlitple values. + if len(batch.dones.shape) == 4: + max_returns = jnp.concatenate( + [max_returns[jnp.newaxis, :, jnp.newaxis, jnp.newaxis], + max_returns[jnp.newaxis, :, jnp.newaxis, jnp.newaxis]], axis=-1 + ) + else: + max_returns = max_returns[jnp.newaxis, :, jnp.newaxis] + mean_scores, max_scores = compute_episodic_stats( + max_returns - batch.values, # Can be negative + batch.dones, + time_average=True + ) + + score_info = {'max_returns': max_env_returns_per_agent} + + return mean_scores, max_scores, score_info + + +def compute_value_disagreement(batch): + mean_scores, max_scores = compute_episodic_stats( + batch.values.std(-1), batch.dones, time_average=True + ) + + return mean_scores, max_scores, None diff --git a/src/run_results_txt/al_all_xpid_against_population_in_all_69_layouts_out.txt b/src/run_results_txt/al_all_xpid_against_population_in_all_69_layouts_out.txt new file mode 100644 index 0000000..5fc6345 --- /dev/null +++ b/src/run_results_txt/al_all_xpid_against_population_in_all_69_layouts_out.txt @@ -0,0 +1,54 @@ +Evaluating Overcooked-CoordRing6_9 against population for 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXcoord_ring_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 + +----------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 39.25+/- 4.04 (max: 127.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 16.08+/- 0.5815 (max: 26.51) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.5467+/- 0.05222 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 6.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 9.25 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +----------------------------------------------------------------------------------------- + +Evaluating Overcooked-ForcedCoord6_9 against population for 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXforced_coord_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 + + +------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 37.32+/- 6.509 (max: 140.4) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 12.47+/- 0.8239 (max: 28.85) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.389+/- 0.06325 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.2 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +------------------------------------------------------------------------------------------- + +Evaluating Overcooked-CounterCircuit6_9 against population for 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXcounter_circuit_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 + +---------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 45.24+/- 2.834 (max: 76.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 15.99+/- 0.572 (max: 29.8) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.7121+/- 0.04546 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 5.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 9.415 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.01 | +---------------------------------------------------------------------------------------------- + +Evaluating Overcooked-AsymmAdvantages6_9 against population for 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXasymm_advantages_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 + +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 152.0+/- 9.415 (max: 230.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 23.14+/- 1.555 (max: 45.34) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.9665+/- 0.01026 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 48.8 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.697 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.74 | +----------------------------------------------------------------------------------------------- +Evaluating Overcooked-CrampedRoom6_9 against population for 9SEED_9_dr-overcookedNonexNonewNone_fs_FIXcramped_room_6_9_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr3e-5g0.99cv0.5ce0.01e5mb1l0.95_pc0.2_h64cf32fc2se5ba_re_0 + +--------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 150.3+/- 5.073 (max: 203.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 19.68+/- 0.7425 (max: 41.5) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.9973+/- 0.0008802 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 94.6 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 12.67 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.98 | +--------------------------------------------------------------------------------------------- diff --git a/src/run_results_txt/eval_all_xpid_against_population_in_all_layouts_out.txt b/src/run_results_txt/eval_all_xpid_against_population_in_all_layouts_out.txt new file mode 100644 index 0000000..d516295 --- /dev/null +++ b/src/run_results_txt/eval_all_xpid_against_population_in_all_layouts_out.txt @@ -0,0 +1,138 @@ +Evaluating dr against population in Overcooked-CoordRing5_5 for dr-overcooked5x5w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing5_5 | 15.35+/- 0.6965 (max: 23.2) | +| eval/a0:test_return_std:Overcooked-CoordRing5_5 | 13.78+/- 0.2437 (max: 17.16) | +| eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.1606+/- 0.01226 (max: 0.35) | +| min:eval/a0:test_return:Overcooked-CoordRing5_5 | 6.4 | +| min:eval/a0:test_return_std:Overcooked-CoordRing5_5 | 10.15 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.02 | +------------------------------------------------------------------------------------------ +Evaluating dr against population in Overcooked-ForcedCoord5_5 for dr-overcooked5x5w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord5_5 | 0.6208+/- 0.06281 (max: 1.8) | +| eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 3.16+/- 0.199 (max: 5.724) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0 | +------------------------------------------------------------------------------------------- +Evaluating dr against population in Overcooked-CrampedRoom5_5 for dr-overcooked5x5w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +-------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom5_5 | 60.29+/- 4.157 (max: 99.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 27.89+/- 1.176 (max: 38.66) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.7073+/- 0.04138 (max: 0.98) | +| min:eval/a0:test_return:Overcooked-CrampedRoom5_5 | 15.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 14.53 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.18 | +-------------------------------------------------------------------------------------------- +Evaluating plr against population in Overcooked-CoordRing5_5 for plr-overcooked5x5w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +-------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing5_5 | 1.212+/- 0.2802 (max: 8.6) | +| eval/a0:test_return_std:Overcooked-CoordRing5_5 | 3.613+/- 0.4178 (max: 11.4) | +| eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.002708+/- 0.00102 (max: 0.04) | +| min:eval/a0:test_return:Overcooked-CoordRing5_5 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing5_5 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.0 | +-------------------------------------------------------------------------------------------- +Evaluating plr against population in Overcooked-ForcedCoord5_5 for plr-overcooked5x5w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord5_5 | 0.3167+/- 0.1021 (max: 3.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 1.334+/- 0.2965 (max: 7.332) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0002083+/- 0.0002083 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating plr against population in Overcooked-CrampedRoom5_5 for plr-overcooked5x5w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom5_5 | 56.94+/- 3.333 (max: 89.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 26.26+/- 0.8617 (max: 35.83) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.7342+/- 0.03327 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom5_5 | 21.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 16.99 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.32 | +------------------------------------------------------------------------------------------- +Evaluating paired against population in Overcooked-CoordRing5_5 for paired-overcooked5x5w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 + +------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing5_5 | 19.15+/- 0.6706 (max: 26.4) | +| eval/a0:test_return_std:Overcooked-CoordRing5_5 | 13.84+/- 0.1905 (max: 16.12) | +| eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.2283+/- 0.01399 (max: 0.42) | +| eval/a1:test_return:Overcooked-CoordRing5_5 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing5_5 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing5_5 | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing5_5 | 9.6 | +| min:eval/a0:test_return_std:Overcooked-CoordRing5_5 | 11.61 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.06 | +| min:eval/a1:test_return:Overcooked-CoordRing5_5 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing5_5 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing5_5 | 0.0 | +------------------------------------------------------------------------------------------ +Evaluating paired against population in Overcooked-ForcedCoord5_5 for paired-overcooked5x5w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 + +------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord5_5 | 0.375+/- 0.05446 (max: 1.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 2.292+/- 0.2046 (max: 5.426) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord5_5 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord5_5 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0 | +------------------------------------------------------------------------------------------- +Evaluating paired against population in Overcooked-CrampedRoom5_5 for paired-overcooked5x5w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 + +------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom5_5 | 66.46+/- 3.787 (max: 101.2) | +| eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 27.66+/- 1.134 (max: 42.87) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.7854+/- 0.02626 (max: 1.0) | +| eval/a1:test_return:Overcooked-CrampedRoom5_5 | 3.1+/- 0.4522 (max: 6.2) | +| eval/a1:test_return_std:Overcooked-CrampedRoom5_5 | 5.039+/- 0.735 (max: 10.08) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.01+/- 0.001459 (max: 0.02) | +| min:eval/a0:test_return:Overcooked-CrampedRoom5_5 | 27.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 17.95 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.44 | +| min:eval/a1:test_return:Overcooked-CrampedRoom5_5 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom5_5 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.0 | +------------------------------------------------------------------------------------------- +Evaluating accel against population in Overcooked-CoordRing5_5 for plr-overcooked5x5w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +----------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing5_5 | 19.28+/- 0.7542 (max: 27.2) | +| eval/a0:test_return_std:Overcooked-CoordRing5_5 | 13.71+/- 0.2366 (max: 17.09) | +| eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.2304+/- 0.0163 (max: 0.45) | +| min:eval/a0:test_return:Overcooked-CoordRing5_5 | 10.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing5_5 | 10.44 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing5_5 | 0.04 | +----------------------------------------------------------------------------------------- +Evaluating accel against population in Overcooked-ForcedCoord5_5 for plr-overcooked5x5w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +---------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord5_5 | 4.071+/- 0.6144 (max: 12.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 6.018+/- 0.5814 (max: 11.54) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.00625+/- 0.001826 (max: 0.05) | +| min:eval/a0:test_return:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord5_5 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord5_5 | 0.0 | +---------------------------------------------------------------------------------------------- +Evaluating accel against population in Overcooked-CrampedRoom5_5 for plr-overcooked5x5w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 + +-------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom5_5 | 63.35+/- 4.381 (max: 97.6) | +| eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 23.11+/- 0.861 (max: 32.16) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.7387+/- 0.04596 (max: 0.99) | +| min:eval/a0:test_return:Overcooked-CrampedRoom5_5 | 10.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom5_5 | 12.14 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom5_5 | 0.06 | +-------------------------------------------------------------------------------------------- diff --git a/src/run_results_txt/eval_xpid_all_cnn_lstm_out.txt b/src/run_results_txt/eval_xpid_all_cnn_lstm_out.txt new file mode 100644 index 0000000..07d6024 --- /dev/null +++ b/src/run_results_txt/eval_xpid_all_cnn_lstm_out.txt @@ -0,0 +1,2381 @@ +Evaluating DR_CNN-LSTM_SEED1 against population in Overcooked-CoordRing6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.199999809265137, 5.400000095367432, 18.600000381469727, 19.0, 17.600000381469727, 16.399999618530273, 12.799999237060547, 12.0, 31.799999237060547, 32.599998474121094, 15.399999618530273, 15.199999809265137, 13.0, 11.0, 25.799999237060547, 22.399999618530273, 16.399999618530273, 20.19999885559082, 18.0, 15.59999942779541, 18.799999237060547, 20.399999618530273, 18.600000381469727, 19.399999618530273, 13.399999618530273, 18.0, 19.19999885559082, 18.0, 18.399999618530273, 18.799999237060547, 17.0, 13.59999942779541, 26.19999885559082, 28.799999237060547, 29.799999237060547, 31.0, 6.599999904632568, 3.799999952316284, 23.799999237060547, 25.399999618530273, 29.0, 26.799999237060547, 20.0, 14.399999618530273, 19.799999237060547, 17.19999885559082, 21.399999618530273, 21.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.199999809265137, 5.400000095367432, 12.799999237060547, 12.0, 13.0, 11.0, 18.0, 15.59999942779541, 13.399999618530273, 18.0, 17.0, 13.59999942779541, 6.599999904632568, 3.799999952316284, 20.0, 14.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [18.600000381469727, 19.0, 31.799999237060547, 32.599998474121094, 25.799999237060547, 22.399999618530273, 18.799999237060547, 20.399999618530273, 19.19999885559082, 18.0, 26.19999885559082, 28.799999237060547, 23.799999237060547, 25.399999618530273, 19.799999237060547, 17.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [17.600000381469727, 16.399999618530273, 15.399999618530273, 15.199999809265137, 16.399999618530273, 20.19999885559082, 18.600000381469727, 19.399999618530273, 18.399999618530273, 18.799999237060547, 29.799999237060547, 31.0, 29.0, 26.799999237060547, 21.399999618530273, 21.399999618530273] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 18.8+/- 0.9826 (max: 32.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 20.99+/- 1.315 (max: 31.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 12.42+/- 1.266 (max: 20.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 22.99+/- 1.244 (max: 32.6) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 13.49+/- 0.304 (max: 17.52) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 13.64+/- 0.5608 (max: 17.52) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.74+/- 0.6188 (max: 15.36) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 14.09+/- 0.3293 (max: 17.37) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.226+/- 0.02245 (max: 0.6) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2469+/- 0.04221 (max: 0.53) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.1187+/- 0.01932 (max: 0.23) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.3125+/- 0.03586 (max: 0.6) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 3.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 15.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 3.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 17.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 7.846 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.69 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 7.846 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.28 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.07 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.17 | +----------------------------------------------------------------------------------------------- +Evaluating DR_CNN-LSTM_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.3999999761581421, 0.0, 0.3999999761581421, 0.19999998807907104, 0.0, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 0.19999998807907104, 0.0, 1.0, 0.3999999761581421, 0.7999999523162842, 0.0, 0.0, 0.19999998807907104, 0.5999999642372131, 0.0, 0.3999999761581421, 0.0, 1.1999999284744263, 0.0, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 0.19999998807907104, 0.3999999761581421, 0.5999999642372131, 0.0, 0.3999999761581421, 0.0, 0.0, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 0.19999998807907104, 0.0, 0.5999999642372131, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 1.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.3999999761581421, 0.0, 1.1999999284744263, 0.0, 1.0, 0.3999999761581421, 0.5999999642372131, 0.0, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 0.0, 0.3999999761581421, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.3999999761581421, 0.19999998807907104, 1.1999999284744263, 0.0, 0.7999999523162842, 0.0, 0.3999999761581421, 0.0, 0.5999999642372131, 0.19999998807907104, 0.3999999761581421, 0.0, 0.7999999523162842, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.19999998807907104, 1.1999999284744263, 0.0, 0.19999998807907104, 0.3999999761581421, 0.0, 0.19999998807907104, 0.19999998807907104, 0.0, 1.0, 0.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.325+/- 0.05144 (max: 1.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.225+/- 0.09106 (max: 1.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.375+/- 0.08921 (max: 1.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.375+/- 0.08732 (max: 1.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 1.944+/- 0.2301 (max: 4.75) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 1.366+/- 0.4049 (max: 4.75) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 2.231+/- 0.3882 (max: 4.75) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 2.236+/- 0.387 (max: 4.75) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------- +Evaluating DR_CNN-LSTM_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [11.0, 8.0, 18.0, 12.399999618530273, 10.800000190734863, 9.0, 5.799999713897705, 4.799999713897705, 11.0, 8.59999942779541, 1.7999999523162842, 0.0, 3.799999952316284, 2.0, 11.399999618530273, 8.0, 10.0, 5.400000095367432, 7.599999904632568, 4.199999809265137, 19.600000381469727, 17.799999237060547, 8.0, 3.3999998569488525, 8.399999618530273, 6.0, 16.399999618530273, 12.399999618530273, 4.799999713897705, 2.799999952316284, 10.59999942779541, 7.799999713897705, 14.199999809265137, 9.59999942779541, 6.599999904632568, 3.5999999046325684, 6.599999904632568, 5.0, 15.199999809265137, 11.399999618530273, 8.800000190734863, 2.3999998569488525, 10.800000190734863, 7.599999904632568, 18.0, 13.0, 7.199999809265137, 2.3999998569488525] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [11.0, 8.0, 5.799999713897705, 4.799999713897705, 3.799999952316284, 2.0, 7.599999904632568, 4.199999809265137, 8.399999618530273, 6.0, 10.59999942779541, 7.799999713897705, 6.599999904632568, 5.0, 10.800000190734863, 7.599999904632568] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [18.0, 12.399999618530273, 11.0, 8.59999942779541, 11.399999618530273, 8.0, 19.600000381469727, 17.799999237060547, 16.399999618530273, 12.399999618530273, 14.199999809265137, 9.59999942779541, 15.199999809265137, 11.399999618530273, 18.0, 13.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [10.800000190734863, 9.0, 1.7999999523162842, 0.0, 10.0, 5.400000095367432, 8.0, 3.3999998569488525, 4.799999713897705, 2.799999952316284, 6.599999904632568, 3.5999999046325684, 8.800000190734863, 2.3999998569488525, 7.199999809265137, 2.3999998569488525] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 8.625+/- 0.6853 (max: 19.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 5.437+/- 0.8184 (max: 10.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 6.875+/- 0.6519 (max: 11.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 13.56+/- 0.9017 (max: 19.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 10.74+/- 0.468 (max: 18.22) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 8.364+/- 0.7386 (max: 11.83) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.884+/- 0.4215 (max: 12.45) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.96+/- 0.4691 (max: 18.22) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.05458+/- 0.01037 (max: 0.26) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.01125+/- 0.004171 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02062+/- 0.005588 (max: 0.07) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1319+/- 0.019 (max: 0.26) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 2.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.58 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.02 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-LSTM_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [2.3999998569488525, 2.200000047683716, 1.399999976158142, 8.199999809265137, 2.799999952316284, 1.5999999046325684, 2.799999952316284, 5.0, 1.399999976158142, 41.20000076293945, 2.5999999046325684, 1.1999999284744263, 2.3999998569488525, 34.599998474121094, 2.3999998569488525, 56.0, 3.3999998569488525, 12.399999618530273, 1.7999999523162842, 13.59999942779541, 2.799999952316284, 10.199999809265137, 3.5999999046325684, 4.199999809265137, 2.5999999046325684, 1.7999999523162842, 2.3999998569488525, 18.399999618530273, 0.5999999642372131, 2.200000047683716, 1.7999999523162842, 3.5999999046325684, 1.7999999523162842, 26.799999237060547, 3.1999998092651367, 1.399999976158142, 2.0, 3.3999998569488525, 1.7999999523162842, 15.799999237060547, 4.199999809265137, 4.799999713897705, 2.3999998569488525, 5.199999809265137, 2.200000047683716, 23.799999237060547, 3.799999952316284, 0.5999999642372131] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [2.3999998569488525, 2.200000047683716, 2.799999952316284, 5.0, 2.3999998569488525, 34.599998474121094, 1.7999999523162842, 13.59999942779541, 2.5999999046325684, 1.7999999523162842, 1.7999999523162842, 3.5999999046325684, 2.0, 3.3999998569488525, 2.3999998569488525, 5.199999809265137] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [1.399999976158142, 8.199999809265137, 1.399999976158142, 41.20000076293945, 2.3999998569488525, 56.0, 2.799999952316284, 10.199999809265137, 2.3999998569488525, 18.399999618530273, 1.7999999523162842, 26.799999237060547, 1.7999999523162842, 15.799999237060547, 2.200000047683716, 23.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [2.799999952316284, 1.5999999046325684, 2.5999999046325684, 1.1999999284744263, 3.3999998569488525, 12.399999618530273, 3.5999999046325684, 4.199999809265137, 0.5999999642372131, 2.200000047683716, 3.1999998092651367, 1.399999976158142, 4.199999809265137, 4.799999713897705, 3.799999952316284, 0.5999999642372131] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 7.433+/- 1.638 (max: 56.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.287+/- 0.6906 (max: 12.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 5.475+/- 2.071 (max: 34.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 13.54+/- 4.072 (max: 56.0) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 9.837+/- 0.9617 (max: 32.25) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.049+/- 0.5832 (max: 12.26) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 9.034+/- 1.551 (max: 30.44) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 13.43+/- 2.128 (max: 32.25) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.07521+/- 0.02527 (max: 0.81) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0075+/- 0.003476 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04187+/- 0.03055 (max: 0.49) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1762+/- 0.06323 (max: 0.81) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 1.8 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 1.4 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 3.412 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 3.412 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 5.724 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 5.103 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-LSTM_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [10.399999618530273, 10.59999942779541, 69.4000015258789, 65.79999542236328, 86.79999542236328, 90.5999984741211, 16.799999237060547, 20.399999618530273, 78.0, 78.5999984741211, 99.19999694824219, 101.79999542236328, 19.0, 16.799999237060547, 99.5999984741211, 95.4000015258789, 103.5999984741211, 102.79999542236328, 21.0, 18.799999237060547, 95.0, 92.0, 80.5999984741211, 82.4000015258789, 7.599999904632568, 7.599999904632568, 80.0, 71.19999694824219, 92.5999984741211, 98.19999694824219, 17.19999885559082, 19.19999885559082, 92.0, 86.0, 102.39999389648438, 106.5999984741211, 18.600000381469727, 15.59999942779541, 85.5999984741211, 88.4000015258789, 98.79999542236328, 102.79999542236328, 22.799999237060547, 20.0, 101.0, 92.4000015258789, 93.5999984741211, 94.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [10.399999618530273, 10.59999942779541, 16.799999237060547, 20.399999618530273, 19.0, 16.799999237060547, 21.0, 18.799999237060547, 7.599999904632568, 7.599999904632568, 17.19999885559082, 19.19999885559082, 18.600000381469727, 15.59999942779541, 22.799999237060547, 20.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [69.4000015258789, 65.79999542236328, 78.0, 78.5999984741211, 99.5999984741211, 95.4000015258789, 95.0, 92.0, 80.0, 71.19999694824219, 92.0, 86.0, 85.5999984741211, 88.4000015258789, 101.0, 92.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [86.79999542236328, 90.5999984741211, 99.19999694824219, 101.79999542236328, 103.5999984741211, 102.79999542236328, 80.5999984741211, 82.4000015258789, 92.5999984741211, 98.19999694824219, 102.39999389648438, 106.5999984741211, 98.79999542236328, 102.79999542236328, 93.5999984741211, 94.19999694824219] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 66.04+/- 5.285 (max: 106.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 96.06+/- 1.954 (max: 106.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 16.4+/- 1.194 (max: 22.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 85.65+/- 2.703 (max: 101.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 24.58+/- 1.152 (max: 36.82) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 31.47+/- 0.8677 (max: 36.82) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 14.41+/- 0.4756 (max: 16.73) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 27.86+/- 0.8525 (max: 34.67) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6948+/- 0.05278 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.955+/- 0.008266 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1906+/- 0.02437 (max: 0.32) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9387+/- 0.009995 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 80.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 65.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.87 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 26.54 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.87 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 22.3 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.86 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.87 | +------------------------------------------------------------------------------------------------- + + + + + + + + + +Evaluating DR_CNN-LSTM_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.199999809265137, 3.1999998092651367, 15.199999809265137, 11.800000190734863, 14.399999618530273, 15.0, 11.800000190734863, 11.800000190734863, 20.399999618530273, 19.399999618530273, 13.59999942779541, 14.199999809265137, 9.399999618530273, 10.399999618530273, 19.399999618530273, 16.399999618530273, 12.399999618530273, 12.399999618530273, 13.0, 9.199999809265137, 11.800000190734863, 12.399999618530273, 10.399999618530273, 11.59999942779541, 12.0, 11.199999809265137, 14.199999809265137, 15.799999237060547, 12.199999809265137, 15.199999809265137, 11.800000190734863, 14.0, 21.0, 21.399999618530273, 21.0, 23.0, 3.1999998092651367, 2.0, 13.0, 15.799999237060547, 20.0, 21.19999885559082, 16.799999237060547, 14.59999942779541, 8.800000190734863, 10.399999618530273, 8.399999618530273, 12.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.199999809265137, 3.1999998092651367, 11.800000190734863, 11.800000190734863, 9.399999618530273, 10.399999618530273, 13.0, 9.199999809265137, 12.0, 11.199999809265137, 11.800000190734863, 14.0, 3.1999998092651367, 2.0, 16.799999237060547, 14.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [15.199999809265137, 11.800000190734863, 20.399999618530273, 19.399999618530273, 19.399999618530273, 16.399999618530273, 11.800000190734863, 12.399999618530273, 14.199999809265137, 15.799999237060547, 21.0, 21.399999618530273, 13.0, 15.799999237060547, 8.800000190734863, 10.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [14.399999618530273, 15.0, 13.59999942779541, 14.199999809265137, 12.399999618530273, 12.399999618530273, 10.399999618530273, 11.59999942779541, 12.199999809265137, 15.199999809265137, 21.0, 23.0, 20.0, 21.19999885559082, 8.399999618530273, 12.399999618530273] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 13.4+/- 0.6984 (max: 23.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 14.84+/- 1.061 (max: 23.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 9.912+/- 1.116 (max: 16.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 15.45+/- 0.9924 (max: 21.4) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.96+/- 0.2779 (max: 16.06) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.84+/- 0.3689 (max: 15.07) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.33+/- 0.6342 (max: 13.96) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.72+/- 0.3488 (max: 16.06) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.09854+/- 0.01253 (max: 0.34) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.09937+/- 0.02589 (max: 0.31) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.06625+/- 0.0119 (max: 0.13) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.13+/- 0.02297 (max: 0.34) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 8.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 8.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.992 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.42 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.04 | +------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-LSTM_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [2.3999998569488525, 0.0, 3.5999999046325684, 0.19999998807907104, 1.1999999284744263, 0.0, 1.399999976158142, 0.7999999523162842, 2.200000047683716, 0.19999998807907104, 1.1999999284744263, 0.0, 2.3999998569488525, 1.1999999284744263, 1.1999999284744263, 0.7999999523162842, 1.5999999046325684, 1.0, 1.7999999523162842, 0.3999999761581421, 1.5999999046325684, 0.0, 1.7999999523162842, 0.19999998807907104, 2.0, 0.3999999761581421, 2.799999952316284, 0.7999999523162842, 3.0, 0.19999998807907104, 2.799999952316284, 0.3999999761581421, 1.5999999046325684, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 1.399999976158142, 0.0, 1.1999999284744263, 0.0, 0.7999999523162842, 0.0, 2.200000047683716, 0.0, 3.1999998092651367, 0.0, 3.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [2.3999998569488525, 0.0, 1.399999976158142, 0.7999999523162842, 2.3999998569488525, 1.1999999284744263, 1.7999999523162842, 0.3999999761581421, 2.0, 0.3999999761581421, 2.799999952316284, 0.3999999761581421, 1.399999976158142, 0.0, 2.200000047683716, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [3.5999999046325684, 0.19999998807907104, 2.200000047683716, 0.19999998807907104, 1.1999999284744263, 0.7999999523162842, 1.5999999046325684, 0.0, 2.799999952316284, 0.7999999523162842, 1.5999999046325684, 0.19999998807907104, 1.1999999284744263, 0.0, 3.1999998092651367, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.5999999046325684, 1.0, 1.7999999523162842, 0.19999998807907104, 3.0, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 0.7999999523162842, 0.0, 3.0, 0.19999998807907104] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.125+/- 0.1505 (max: 3.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.925+/- 0.2496 (max: 3.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.225+/- 0.2407 (max: 2.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.225+/- 0.2977 (max: 3.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.793+/- 0.3557 (max: 7.684) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.407+/- 0.6112 (max: 7.681) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.075+/- 0.607 (max: 6.94) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.898+/- 0.6569 (max: 7.684) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0002083+/- 0.0002083 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-LSTM_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [9.199999809265137, 9.399999618530273, 16.399999618530273, 10.399999618530273, 12.799999237060547, 8.59999942779541, 5.199999809265137, 6.199999809265137, 6.199999809265137, 4.799999713897705, 1.5999999046325684, 0.5999999642372131, 1.7999999523162842, 1.1999999284744263, 7.399999618530273, 5.599999904632568, 6.399999618530273, 4.0, 7.799999713897705, 6.0, 18.399999618530273, 14.0, 9.800000190734863, 5.199999809265137, 7.599999904632568, 5.400000095367432, 16.19999885559082, 11.800000190734863, 3.3999998569488525, 2.5999999046325684, 9.0, 7.399999618530273, 18.399999618530273, 11.0, 6.199999809265137, 2.799999952316284, 6.399999618530273, 3.799999952316284, 11.0, 7.199999809265137, 8.199999809265137, 4.199999809265137, 11.199999809265137, 9.0, 25.799999237060547, 16.600000381469727, 9.59999942779541, 3.799999952316284] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [9.199999809265137, 9.399999618530273, 5.199999809265137, 6.199999809265137, 1.7999999523162842, 1.1999999284744263, 7.799999713897705, 6.0, 7.599999904632568, 5.400000095367432, 9.0, 7.399999618530273, 6.399999618530273, 3.799999952316284, 11.199999809265137, 9.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [16.399999618530273, 10.399999618530273, 6.199999809265137, 4.799999713897705, 7.399999618530273, 5.599999904632568, 18.399999618530273, 14.0, 16.19999885559082, 11.800000190734863, 18.399999618530273, 11.0, 11.0, 7.199999809265137, 25.799999237060547, 16.600000381469727] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [12.799999237060547, 8.59999942779541, 1.5999999046325684, 0.5999999642372131, 6.399999618530273, 4.0, 9.800000190734863, 5.199999809265137, 3.3999998569488525, 2.5999999046325684, 6.199999809265137, 2.799999952316284, 8.199999809265137, 4.199999809265137, 9.59999942779541, 3.799999952316284] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 8.283+/- 0.7421 (max: 25.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 5.612+/- 0.8486 (max: 12.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 6.662+/- 0.6901 (max: 11.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 12.57+/- 1.45 (max: 25.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 10.52+/- 0.4794 (max: 21.41) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 8.498+/- 0.5618 (max: 12.5) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.672+/- 0.5273 (max: 12.46) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.39+/- 0.8337 (max: 21.41) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.05+/- 0.01179 (max: 0.38) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00875+/- 0.00499 (max: 0.08) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02+/- 0.005845 (max: 0.07) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1212+/- 0.02711 (max: 0.38) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.6 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.6 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 4.8 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 3.412 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 3.412 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 4.75 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.98 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-LSTM_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [1.1999999284744263, 3.1999998092651367, 0.5999999642372131, 6.799999713897705, 1.5999999046325684, 1.0, 1.399999976158142, 7.0, 1.1999999284744263, 41.39999771118164, 3.3999998569488525, 1.0, 2.0, 31.399999618530273, 1.399999976158142, 52.599998474121094, 2.3999998569488525, 8.199999809265137, 1.399999976158142, 13.0, 1.0, 6.0, 4.400000095367432, 1.399999976158142, 1.1999999284744263, 1.0, 1.399999976158142, 15.59999942779541, 0.5999999642372131, 0.7999999523162842, 1.1999999284744263, 5.0, 1.0, 27.19999885559082, 1.7999999523162842, 0.3999999761581421, 0.5999999642372131, 3.1999998092651367, 1.0, 15.0, 1.5999999046325684, 2.5999999046325684, 1.5999999046325684, 3.799999952316284, 1.0, 18.600000381469727, 3.3999998569488525, 0.5999999642372131] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [1.1999999284744263, 3.1999998092651367, 1.399999976158142, 7.0, 2.0, 31.399999618530273, 1.399999976158142, 13.0, 1.1999999284744263, 1.0, 1.1999999284744263, 5.0, 0.5999999642372131, 3.1999998092651367, 1.5999999046325684, 3.799999952316284] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.5999999642372131, 6.799999713897705, 1.1999999284744263, 41.39999771118164, 1.399999976158142, 52.599998474121094, 1.0, 6.0, 1.399999976158142, 15.59999942779541, 1.0, 27.19999885559082, 1.0, 15.0, 1.0, 18.600000381469727] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.5999999046325684, 1.0, 3.3999998569488525, 1.0, 2.3999998569488525, 8.199999809265137, 4.400000095367432, 1.399999976158142, 0.5999999642372131, 0.7999999523162842, 1.7999999523162842, 0.3999999761581421, 1.5999999046325684, 2.5999999046325684, 3.3999998569488525, 0.5999999642372131] +-------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 6.358+/- 1.576 (max: 52.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 2.2+/- 0.4943 (max: 8.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 4.887+/- 1.934 (max: 31.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 11.99+/- 3.994 (max: 52.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.624+/- 0.9728 (max: 33.54) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 5.68+/- 0.4988 (max: 9.837) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.102+/- 1.428 (max: 25.96) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.09+/- 2.282 (max: 33.54) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.06833+/- 0.02479 (max: 0.72) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04625+/- 0.03188 (max: 0.5) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1581+/- 0.0623 (max: 0.72) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.6 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 2.8 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 2.8 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 3.412 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 3.412 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-LSTM_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [12.199999809265137, 10.399999618530273, 76.4000015258789, 69.5999984741211, 99.4000015258789, 98.19999694824219, 14.799999237060547, 13.199999809265137, 76.19999694824219, 72.19999694824219, 95.5999984741211, 102.39999389648438, 18.799999237060547, 12.399999618530273, 95.5999984741211, 97.4000015258789, 100.0, 97.19999694824219, 19.19999885559082, 16.0, 85.79999542236328, 84.5999984741211, 86.19999694824219, 84.5999984741211, 7.599999904632568, 10.59999942779541, 75.19999694824219, 68.79999542236328, 93.0, 97.5999984741211, 20.600000381469727, 19.19999885559082, 82.4000015258789, 85.4000015258789, 102.39999389648438, 105.79999542236328, 19.19999885559082, 15.0, 77.0, 79.79999542236328, 101.79999542236328, 104.39999389648438, 21.600000381469727, 18.0, 87.0, 84.4000015258789, 100.4000015258789, 92.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [12.199999809265137, 10.399999618530273, 14.799999237060547, 13.199999809265137, 18.799999237060547, 12.399999618530273, 19.19999885559082, 16.0, 7.599999904632568, 10.59999942779541, 20.600000381469727, 19.19999885559082, 19.19999885559082, 15.0, 21.600000381469727, 18.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [76.4000015258789, 69.5999984741211, 76.19999694824219, 72.19999694824219, 95.5999984741211, 97.4000015258789, 85.79999542236328, 84.5999984741211, 75.19999694824219, 68.79999542236328, 82.4000015258789, 85.4000015258789, 77.0, 79.79999542236328, 87.0, 84.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [99.4000015258789, 98.19999694824219, 95.5999984741211, 102.39999389648438, 100.0, 97.19999694824219, 86.19999694824219, 84.5999984741211, 93.0, 97.5999984741211, 102.39999389648438, 105.79999542236328, 101.79999542236328, 104.39999389648438, 100.4000015258789, 92.79999542236328] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 64.76+/- 5.248 (max: 105.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 97.61+/- 1.502 (max: 105.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.55+/- 1.042 (max: 21.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 81.11+/- 2.08 (max: 97.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 24.14+/- 1.036 (max: 36.09) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 29.95+/- 0.7254 (max: 36.09) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 14.76+/- 0.4026 (max: 17.85) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 27.72+/- 0.6638 (max: 32.16) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6962+/- 0.05406 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9725+/- 0.005951 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1769+/- 0.01944 (max: 0.29) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9394+/- 0.00704 (max: 0.99) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 84.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 68.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.93 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 25.23 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 11.93 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 23.09 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.06 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.92 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.06 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9 | +------------------------------------------------------------------------------------------------- + + + + + + + + + + +Evaluating DR_CNN-LSTM_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-CoordRing6_9, v [3.0, 4.799999713897705, 18.600000381469727, 15.199999809265137, 24.399999618530273, 21.799999237060547, 14.0, 14.799999237060547, 36.20000076293945, 35.20000076293945, 26.0, 27.19999885559082, 12.0, 12.0, 26.799999237060547, 26.0, 23.799999237060547, 23.600000381469727, 15.799999237060547, 10.59999942779541, 20.19999885559082, 21.600000381469727, 22.799999237060547, 18.600000381469727, 17.19999885559082, 18.799999237060547, 27.0, 26.0, 25.799999237060547, 26.399999618530273, 15.799999237060547, 14.0, 26.599998474121094, 29.799999237060547, 34.0, 30.0, 4.599999904632568, 3.1999998092651367, 24.19999885559082, 26.599998474121094, 29.0, 29.0, 11.800000190734863, 15.0, 21.19999885559082, 21.399999618530273, 25.19999885559082, 20.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [3.0, 4.799999713897705, 14.0, 14.799999237060547, 12.0, 12.0, 15.799999237060547, 10.59999942779541, 17.19999885559082, 18.799999237060547, 15.799999237060547, 14.0, 4.599999904632568, 3.1999998092651367, 11.800000190734863, 15.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [18.600000381469727, 15.199999809265137, 36.20000076293945, 35.20000076293945, 26.799999237060547, 26.0, 20.19999885559082, 21.600000381469727, 27.0, 26.0, 26.599998474121094, 29.799999237060547, 24.19999885559082, 26.599998474121094, 21.19999885559082, 21.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [24.399999618530273, 21.799999237060547, 26.0, 27.19999885559082, 23.799999237060547, 23.600000381469727, 22.799999237060547, 18.600000381469727, 25.799999237060547, 26.399999618530273, 34.0, 30.0, 29.0, 29.0, 25.19999885559082, 20.799999237060547] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 20.8+/- 1.165 (max: 36.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 25.52+/- 0.9564 (max: 34.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.71+/- 1.279 (max: 18.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 25.16+/- 1.394 (max: 36.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 13.49+/- 0.3268 (max: 17.55) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 14.33+/- 0.5036 (max: 17.55) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.02+/- 0.6642 (max: 15.18) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 14.11+/- 0.3018 (max: 16.98) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.2794+/- 0.0261 (max: 0.73) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.3756+/- 0.02985 (max: 0.65) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.09187+/- 0.01713 (max: 0.2) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.3706+/- 0.04128 (max: 0.73) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 3.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 18.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 3.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 15.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 7.141 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.14 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 7.141 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.98 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.21 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.15 | +----------------------------------------------------------------------------------------------- +Evaluating DR_CNN-LSTM_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [1.399999976158142, 0.3999999761581421, 2.0, 0.0, 1.0, 0.0, 2.200000047683716, 0.19999998807907104, 2.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.0, 0.0, 2.200000047683716, 0.0, 1.399999976158142, 0.0, 0.7999999523162842, 0.0, 2.0, 0.0, 2.0, 0.0, 0.5999999642372131, 0.19999998807907104, 1.1999999284744263, 0.19999998807907104, 1.0, 0.19999998807907104, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 0.7999999523162842, 0.19999998807907104, 0.3999999761581421, 0.0, 1.1999999284744263, 0.0, 2.3999998569488525, 0.0, 4.199999809265137, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [1.399999976158142, 0.3999999761581421, 2.200000047683716, 0.19999998807907104, 1.1999999284744263, 0.0, 2.200000047683716, 0.0, 2.0, 0.0, 1.1999999284744263, 0.19999998807907104, 0.3999999761581421, 0.0, 1.1999999284744263, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [2.0, 0.0, 2.0, 0.19999998807907104, 1.1999999284744263, 0.0, 1.399999976158142, 0.0, 2.0, 0.0, 1.0, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 2.3999998569488525, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.0, 0.0, 1.0, 0.19999998807907104, 1.0, 0.0, 0.7999999523162842, 0.0, 0.5999999642372131, 0.19999998807907104, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 4.199999809265137, 0.0] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.7417+/- 0.1311 (max: 4.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.6+/- 0.2595 (max: 4.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.7875+/- 0.2085 (max: 2.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.8375+/- 0.2208 (max: 2.4) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 2.779+/- 0.3499 (max: 8.146) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 2.333+/- 0.5881 (max: 8.146) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 2.965+/- 0.6155 (max: 6.258) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.037+/- 0.6372 (max: 6.499) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-LSTM_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [8.0, 5.400000095367432, 14.199999809265137, 9.800000190734863, 13.399999618530273, 13.399999618530273, 5.199999809265137, 7.0, 12.0, 6.599999904632568, 3.0, 1.399999976158142, 5.400000095367432, 3.3999998569488525, 12.799999237060547, 10.399999618530273, 12.399999618530273, 8.800000190734863, 5.400000095367432, 6.799999713897705, 17.799999237060547, 17.19999885559082, 11.199999809265137, 4.0, 6.799999713897705, 7.0, 21.399999618530273, 18.600000381469727, 9.199999809265137, 7.599999904632568, 8.800000190734863, 7.799999713897705, 14.799999237060547, 10.800000190734863, 9.800000190734863, 7.399999618530273, 7.0, 5.0, 15.799999237060547, 12.199999809265137, 7.0, 3.5999999046325684, 6.399999618530273, 5.799999713897705, 12.799999237060547, 10.800000190734863, 8.399999618530273, 4.199999809265137] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [8.0, 5.400000095367432, 5.199999809265137, 7.0, 5.400000095367432, 3.3999998569488525, 5.400000095367432, 6.799999713897705, 6.799999713897705, 7.0, 8.800000190734863, 7.799999713897705, 7.0, 5.0, 6.399999618530273, 5.799999713897705] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [14.199999809265137, 9.800000190734863, 12.0, 6.599999904632568, 12.799999237060547, 10.399999618530273, 17.799999237060547, 17.19999885559082, 21.399999618530273, 18.600000381469727, 14.799999237060547, 10.800000190734863, 15.799999237060547, 12.199999809265137, 12.799999237060547, 10.800000190734863] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [13.399999618530273, 13.399999618530273, 3.0, 1.399999976158142, 12.399999618530273, 8.800000190734863, 11.199999809265137, 4.0, 9.199999809265137, 7.599999904632568, 9.800000190734863, 7.399999618530273, 7.0, 3.5999999046325684, 8.399999618530273, 4.199999809265137] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 9.25+/- 0.6438 (max: 21.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 7.8+/- 0.9413 (max: 13.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 6.325+/- 0.3376 (max: 8.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 13.62+/- 0.9536 (max: 21.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.14+/- 0.3932 (max: 17.55) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 9.623+/- 0.5015 (max: 13.28) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.849+/- 0.2929 (max: 12.11) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.96+/- 0.5662 (max: 17.55) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.05396+/- 0.01054 (max: 0.31) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.01812+/- 0.007704 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.01687+/- 0.004806 (max: 0.06) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1269+/- 0.02073 (max: 0.31) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 6.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.513 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 9.82 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-LSTM_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.7999999523162842, 3.799999952316284, 0.19999998807907104, 12.399999618530273, 1.399999976158142, 6.399999618530273, 0.0, 9.399999618530273, 0.5999999642372131, 39.0, 1.0, 2.799999952316284, 0.3999999761581421, 38.79999923706055, 0.0, 58.79999923706055, 1.5999999046325684, 10.59999942779541, 0.5999999642372131, 14.0, 0.19999998807907104, 10.199999809265137, 0.7999999523162842, 2.0, 0.19999998807907104, 3.5999999046325684, 0.19999998807907104, 18.600000381469727, 0.0, 6.0, 0.5999999642372131, 4.599999904632568, 0.0, 31.19999885559082, 0.19999998807907104, 4.799999713897705, 0.7999999523162842, 4.0, 0.0, 18.600000381469727, 1.7999999523162842, 4.799999713897705, 0.3999999761581421, 5.0, 0.0, 20.799999237060547, 0.7999999523162842, 3.5999999046325684] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.7999999523162842, 3.799999952316284, 0.0, 9.399999618530273, 0.3999999761581421, 38.79999923706055, 0.5999999642372131, 14.0, 0.19999998807907104, 3.5999999046325684, 0.5999999642372131, 4.599999904632568, 0.7999999523162842, 4.0, 0.3999999761581421, 5.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.19999998807907104, 12.399999618530273, 0.5999999642372131, 39.0, 0.0, 58.79999923706055, 0.19999998807907104, 10.199999809265137, 0.19999998807907104, 18.600000381469727, 0.0, 31.19999885559082, 0.0, 18.600000381469727, 0.0, 20.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.399999976158142, 6.399999618530273, 1.0, 2.799999952316284, 1.5999999046325684, 10.59999942779541, 0.7999999523162842, 2.0, 0.0, 6.0, 0.19999998807907104, 4.799999713897705, 1.7999999523162842, 4.799999713897705, 0.7999999523162842, 3.5999999046325684] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 7.217+/- 1.76 (max: 58.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.037+/- 0.7149 (max: 10.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 5.437+/- 2.423 (max: 38.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 13.17+/- 4.358 (max: 58.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.597+/- 1.147 (max: 31.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 6.771+/- 0.8763 (max: 12.48) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.538+/- 1.689 (max: 27.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 11.48+/- 2.808 (max: 31.6) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.08604+/- 0.0262 (max: 0.78) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.01562+/- 0.004279 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.05688+/- 0.03871 (max: 0.61) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1856+/- 0.06235 (max: 0.78) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-LSTM_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [11.59999942779541, 11.199999809265137, 78.0, 72.19999694824219, 98.79999542236328, 101.0, 18.399999618530273, 15.199999809265137, 79.4000015258789, 76.5999984741211, 100.79999542236328, 95.5999984741211, 21.0, 19.600000381469727, 100.79999542236328, 101.0, 99.0, 100.19999694824219, 17.600000381469727, 20.600000381469727, 92.0, 91.4000015258789, 86.5999984741211, 84.5999984741211, 10.199999809265137, 11.800000190734863, 78.19999694824219, 76.19999694824219, 98.5999984741211, 98.0, 21.799999237060547, 19.799999237060547, 88.4000015258789, 86.5999984741211, 101.19999694824219, 102.79999542236328, 21.399999618530273, 18.399999618530273, 87.79999542236328, 86.19999694824219, 103.79999542236328, 104.0, 21.399999618530273, 19.799999237060547, 96.4000015258789, 94.5999984741211, 96.5999984741211, 94.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [11.59999942779541, 11.199999809265137, 18.399999618530273, 15.199999809265137, 21.0, 19.600000381469727, 17.600000381469727, 20.600000381469727, 10.199999809265137, 11.800000190734863, 21.799999237060547, 19.799999237060547, 21.399999618530273, 18.399999618530273, 21.399999618530273, 19.799999237060547] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [78.0, 72.19999694824219, 79.4000015258789, 76.5999984741211, 100.79999542236328, 101.0, 92.0, 91.4000015258789, 78.19999694824219, 76.19999694824219, 88.4000015258789, 86.5999984741211, 87.79999542236328, 86.19999694824219, 96.4000015258789, 94.5999984741211] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [98.79999542236328, 101.0, 100.79999542236328, 95.5999984741211, 99.0, 100.19999694824219, 86.5999984741211, 84.5999984741211, 98.5999984741211, 98.0, 101.19999694824219, 102.79999542236328, 103.79999542236328, 104.0, 96.5999984741211, 94.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 67.32+/- 5.266 (max: 104.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 97.85+/- 1.384 (max: 104.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 17.49+/- 1.027 (max: 21.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 86.61+/- 2.271 (max: 101.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 24.09+/- 1.038 (max: 36.21) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 31.33+/- 0.6372 (max: 36.21) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 15.21+/- 0.3948 (max: 17.06) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 25.72+/- 0.7912 (max: 31.25) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.7142+/- 0.05204 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9631+/- 0.004155 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.215+/- 0.022 (max: 0.34) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9644+/- 0.006517 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 10.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 84.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 10.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 72.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 12.71 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 26.94 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 12.71 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 20.88 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.08 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.94 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.08 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.92 | +------------------------------------------------------------------------------------------------- + + + + + + + +Evaluating PLR_CNN-LSTM_SEED1 against population in Overcooked-CoordRing6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.799999713897705, 3.3999998569488525, 18.399999618530273, 14.799999237060547, 21.399999618530273, 21.799999237060547, 12.799999237060547, 11.399999618530273, 26.399999618530273, 29.799999237060547, 21.0, 23.399999618530273, 11.0, 11.800000190734863, 21.799999237060547, 22.600000381469727, 24.19999885559082, 22.399999618530273, 13.0, 10.399999618530273, 16.0, 16.799999237060547, 18.600000381469727, 18.399999618530273, 15.399999618530273, 17.19999885559082, 18.19999885559082, 21.399999618530273, 21.399999618530273, 24.399999618530273, 11.0, 13.199999809265137, 20.19999885559082, 20.600000381469727, 31.399999618530273, 34.0, 2.799999952316284, 5.400000095367432, 16.600000381469727, 21.0, 27.599998474121094, 30.799999237060547, 14.399999618530273, 16.19999885559082, 14.399999618530273, 16.0, 22.600000381469727, 26.599998474121094] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.799999713897705, 3.3999998569488525, 12.799999237060547, 11.399999618530273, 11.0, 11.800000190734863, 13.0, 10.399999618530273, 15.399999618530273, 17.19999885559082, 11.0, 13.199999809265137, 2.799999952316284, 5.400000095367432, 14.399999618530273, 16.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [18.399999618530273, 14.799999237060547, 26.399999618530273, 29.799999237060547, 21.799999237060547, 22.600000381469727, 16.0, 16.799999237060547, 18.19999885559082, 21.399999618530273, 20.19999885559082, 20.600000381469727, 16.600000381469727, 21.0, 14.399999618530273, 16.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [21.399999618530273, 21.799999237060547, 21.0, 23.399999618530273, 24.19999885559082, 22.399999618530273, 18.600000381469727, 18.399999618530273, 21.399999618530273, 24.399999618530273, 31.399999618530273, 34.0, 27.599998474121094, 30.799999237060547, 22.600000381469727, 26.599998474121094] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 18.32+/- 1.028 (max: 34.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 24.37+/- 1.139 (max: 34.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 10.89+/- 1.126 (max: 17.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 19.69+/- 1.054 (max: 29.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.25+/- 0.3323 (max: 15.94) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.57+/- 0.6373 (max: 15.94) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.98+/- 0.6407 (max: 15.5) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.21+/- 0.3513 (max: 15.61) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.196+/- 0.02201 (max: 0.68) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2975+/- 0.04513 (max: 0.68) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.08312+/- 0.0165 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2075+/- 0.02643 (max: 0.44) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 18.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 14.4 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.127 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 10.98 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.11 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.05 | +----------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-LSTM_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 + +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [1.0, 0.0, 1.1999999284744263, 0.0, 0.5999999642372131, 0.0, 1.1999999284744263, 0.3999999761581421, 1.7999999523162842, 0.7999999523162842, 0.3999999761581421, 0.0, 0.5999999642372131, 0.3999999761581421, 1.0, 0.5999999642372131, 0.19999998807907104, 0.19999998807907104, 0.7999999523162842, 0.0, 1.7999999523162842, 0.0, 1.0, 0.19999998807907104, 1.0, 0.0, 0.7999999523162842, 0.0, 1.7999999523162842, 0.0, 1.1999999284744263, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.3999999761581421, 0.3999999761581421, 0.0, 0.7999999523162842, 0.19999998807907104, 0.5999999642372131, 0.0, 1.5999999046325684, 0.19999998807907104, 1.5999999046325684, 0.7999999523162842, 2.5999999046325684, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [1.0, 0.0, 1.1999999284744263, 0.3999999761581421, 0.5999999642372131, 0.3999999761581421, 0.7999999523162842, 0.0, 1.0, 0.0, 1.1999999284744263, 0.0, 0.3999999761581421, 0.0, 1.5999999046325684, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [1.1999999284744263, 0.0, 1.7999999523162842, 0.7999999523162842, 1.0, 0.5999999642372131, 1.7999999523162842, 0.0, 0.7999999523162842, 0.0, 0.5999999642372131, 0.0, 0.7999999523162842, 0.19999998807907104, 1.5999999046325684, 0.7999999523162842] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.5999999642372131, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.19999998807907104, 1.0, 0.19999998807907104, 1.7999999523162842, 0.0, 0.5999999642372131, 0.3999999761581421, 0.5999999642372131, 0.0, 2.5999999046325684, 0.19999998807907104] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.6167+/- 0.08949 (max: 2.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.55+/- 0.1794 (max: 2.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.55+/- 0.131 (max: 1.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.75+/- 0.1555 (max: 1.8) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 2.759+/- 0.2904 (max: 6.726) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 2.538+/- 0.5017 (max: 6.726) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 2.585+/- 0.5004 (max: 5.426) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.154+/- 0.5243 (max: 5.724) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-LSTM_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [12.59999942779541, 10.399999618530273, 20.19999885559082, 15.399999618530273, 16.0, 16.0, 8.199999809265137, 6.0, 14.0, 13.399999618530273, 3.3999998569488525, 1.399999976158142, 5.199999809265137, 2.200000047683716, 12.799999237060547, 7.0, 13.399999618530273, 5.599999904632568, 10.800000190734863, 6.0, 24.19999885559082, 19.0, 9.800000190734863, 4.799999713897705, 11.0, 8.199999809265137, 21.19999885559082, 18.0, 10.800000190734863, 3.799999952316284, 14.0, 8.800000190734863, 23.799999237060547, 14.0, 11.0, 6.199999809265137, 8.800000190734863, 7.599999904632568, 20.19999885559082, 17.19999885559082, 9.0, 4.0, 14.0, 9.59999942779541, 25.599998474121094, 22.799999237060547, 9.199999809265137, 3.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [12.59999942779541, 10.399999618530273, 8.199999809265137, 6.0, 5.199999809265137, 2.200000047683716, 10.800000190734863, 6.0, 11.0, 8.199999809265137, 14.0, 8.800000190734863, 8.800000190734863, 7.599999904632568, 14.0, 9.59999942779541] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [20.19999885559082, 15.399999618530273, 14.0, 13.399999618530273, 12.799999237060547, 7.0, 24.19999885559082, 19.0, 21.19999885559082, 18.0, 23.799999237060547, 14.0, 20.19999885559082, 17.19999885559082, 25.599998474121094, 22.799999237060547] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [16.0, 16.0, 3.3999998569488525, 1.399999976158142, 13.399999618530273, 5.599999904632568, 9.800000190734863, 4.799999713897705, 10.800000190734863, 3.799999952316284, 11.0, 6.199999809265137, 9.0, 4.0, 9.199999809265137, 3.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 11.66+/- 0.905 (max: 25.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 7.962+/- 1.159 (max: 16.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 8.962+/- 0.8023 (max: 14.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 18.05+/- 1.262 (max: 25.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.84+/- 0.494 (max: 21.18) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 9.465+/- 0.4914 (max: 12.33) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.6+/- 0.4023 (max: 13.46) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 15.47+/- 0.7308 (max: 21.18) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.09479+/- 0.01766 (max: 0.44) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.02187+/- 0.007917 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.03187+/- 0.007428 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.2306+/- 0.03104 (max: 0.44) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 2.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 7.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 6.258 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 9.539 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-LSTM_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.3999999761581421, 3.1999998092651367, 0.7999999523162842, 10.0, 1.0, 2.3999998569488525, 0.7999999523162842, 6.0, 1.0, 39.20000076293945, 0.3999999761581421, 2.3999998569488525, 0.7999999523162842, 32.79999923706055, 1.399999976158142, 56.79999923706055, 1.1999999284744263, 10.59999942779541, 0.5999999642372131, 14.0, 1.0, 8.0, 1.1999999284744263, 3.3999998569488525, 0.3999999761581421, 3.0, 1.0, 16.799999237060547, 1.1999999284744263, 4.799999713897705, 0.3999999761581421, 4.799999713897705, 1.399999976158142, 22.19999885559082, 0.7999999523162842, 2.5999999046325684, 0.19999998807907104, 2.3999998569488525, 0.7999999523162842, 14.399999618530273, 1.399999976158142, 5.799999713897705, 0.19999998807907104, 4.799999713897705, 0.19999998807907104, 18.0, 1.0, 2.5999999046325684] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.3999999761581421, 3.1999998092651367, 0.7999999523162842, 6.0, 0.7999999523162842, 32.79999923706055, 0.5999999642372131, 14.0, 0.3999999761581421, 3.0, 0.3999999761581421, 4.799999713897705, 0.19999998807907104, 2.3999998569488525, 0.19999998807907104, 4.799999713897705] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.7999999523162842, 10.0, 1.0, 39.20000076293945, 1.399999976158142, 56.79999923706055, 1.0, 8.0, 1.0, 16.799999237060547, 1.399999976158142, 22.19999885559082, 0.7999999523162842, 14.399999618530273, 0.19999998807907104, 18.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.0, 2.3999998569488525, 0.3999999761581421, 2.3999998569488525, 1.1999999284744263, 10.59999942779541, 1.1999999284744263, 3.3999998569488525, 1.1999999284744263, 4.799999713897705, 0.7999999523162842, 2.5999999046325684, 1.399999976158142, 5.799999713897705, 1.0, 2.5999999046325684] +-------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 6.471+/- 1.607 (max: 56.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 2.675+/- 0.6462 (max: 10.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 4.675+/- 2.073 (max: 32.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 12.06+/- 4.038 (max: 56.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.634+/- 1.013 (max: 31.46) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 6.425+/- 0.6483 (max: 12.79) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.42+/- 1.666 (max: 27.79) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.06+/- 2.292 (max: 31.46) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.07125+/- 0.02422 (max: 0.79) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.008125+/- 0.005018 (max: 0.08) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04562+/- 0.02999 (max: 0.47) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.16+/- 0.06126 (max: 0.79) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 2.8 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-LSTM_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [9.0, 9.399999618530273, 70.79999542236328, 65.19999694824219, 87.79999542236328, 91.19999694824219, 12.0, 16.600000381469727, 71.0, 66.79999542236328, 100.4000015258789, 99.4000015258789, 15.59999942779541, 15.199999809265137, 86.0, 88.4000015258789, 94.4000015258789, 93.79999542236328, 15.399999618530273, 14.399999618530273, 83.19999694824219, 80.5999984741211, 66.19999694824219, 71.0, 8.399999618530273, 8.800000190734863, 67.5999984741211, 66.19999694824219, 87.4000015258789, 96.19999694824219, 17.19999885559082, 13.799999237060547, 87.19999694824219, 81.5999984741211, 104.19999694824219, 102.39999389648438, 16.399999618530273, 14.199999809265137, 83.79999542236328, 75.79999542236328, 95.0, 95.0, 16.399999618530273, 15.59999942779541, 83.79999542236328, 79.79999542236328, 91.5999984741211, 94.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [9.0, 9.399999618530273, 12.0, 16.600000381469727, 15.59999942779541, 15.199999809265137, 15.399999618530273, 14.399999618530273, 8.399999618530273, 8.800000190734863, 17.19999885559082, 13.799999237060547, 16.399999618530273, 14.199999809265137, 16.399999618530273, 15.59999942779541] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [70.79999542236328, 65.19999694824219, 71.0, 66.79999542236328, 86.0, 88.4000015258789, 83.19999694824219, 80.5999984741211, 67.5999984741211, 66.19999694824219, 87.19999694824219, 81.5999984741211, 83.79999542236328, 75.79999542236328, 83.79999542236328, 79.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [87.79999542236328, 91.19999694824219, 100.4000015258789, 99.4000015258789, 94.4000015258789, 93.79999542236328, 66.19999694824219, 71.0, 87.4000015258789, 96.19999694824219, 104.19999694824219, 102.39999389648438, 95.0, 95.0, 91.5999984741211, 94.19999694824219] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 60.97+/- 5.077 (max: 104.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 91.89+/- 2.564 (max: 104.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 13.65+/- 0.7743 (max: 17.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 77.36+/- 2.051 (max: 88.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 23.42+/- 1.069 (max: 35.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 28.62+/- 0.6682 (max: 34.84) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.64+/- 0.3778 (max: 16.25) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 27.99+/- 0.76 (max: 35.4) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6687+/- 0.05689 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9581+/- 0.0113 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1212+/- 0.01363 (max: 0.2) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9269+/- 0.00982 (max: 0.98) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 8.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 66.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 8.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 65.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.02 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 23.68 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 11.02 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 23.61 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.81 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.85 | +------------------------------------------------------------------------------------------------- + + + + + + + + +Evaluating PLR_CNN-LSTM_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [1.7999999523162842, 2.5999999046325684, 5.0, 5.0, 2.799999952316284, 2.5999999046325684, 6.199999809265137, 5.799999713897705, 8.399999618530273, 10.0, 5.199999809265137, 3.799999952316284, 6.199999809265137, 7.599999904632568, 7.399999618530273, 6.199999809265137, 5.799999713897705, 5.599999904632568, 10.399999618530273, 6.799999713897705, 7.799999713897705, 7.599999904632568, 5.400000095367432, 4.199999809265137, 5.599999904632568, 4.799999713897705, 1.399999976158142, 0.7999999523162842, 3.799999952316284, 5.0, 7.0, 9.399999618530273, 12.799999237060547, 14.399999618530273, 12.799999237060547, 10.59999942779541, 0.7999999523162842, 0.5999999642372131, 6.599999904632568, 5.599999904632568, 11.199999809265137, 14.59999942779541, 9.199999809265137, 9.59999942779541, 5.599999904632568, 7.399999618530273, 5.599999904632568, 9.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [1.7999999523162842, 2.5999999046325684, 6.199999809265137, 5.799999713897705, 6.199999809265137, 7.599999904632568, 10.399999618530273, 6.799999713897705, 5.599999904632568, 4.799999713897705, 7.0, 9.399999618530273, 0.7999999523162842, 0.5999999642372131, 9.199999809265137, 9.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [5.0, 5.0, 8.399999618530273, 10.0, 7.399999618530273, 6.199999809265137, 7.799999713897705, 7.599999904632568, 1.399999976158142, 0.7999999523162842, 12.799999237060547, 14.399999618530273, 6.599999904632568, 5.599999904632568, 5.599999904632568, 7.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [2.799999952316284, 2.5999999046325684, 5.199999809265137, 3.799999952316284, 5.799999713897705, 5.599999904632568, 5.400000095367432, 4.199999809265137, 3.799999952316284, 5.0, 12.799999237060547, 10.59999942779541, 11.199999809265137, 14.59999942779541, 5.599999904632568, 9.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 6.55+/- 0.4913 (max: 14.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 6.75+/- 0.9253 (max: 14.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 5.9+/- 0.778 (max: 10.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 7.0+/- 0.8737 (max: 14.4) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 9.51+/- 0.3464 (max: 14.57) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.214+/- 0.3826 (max: 11.77) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 9.459+/- 0.7367 (max: 12.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 9.858+/- 0.6515 (max: 14.57) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0225+/- 0.004371 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.01312+/- 0.005456 (max: 0.07) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.02687+/- 0.006565 (max: 0.08) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0275+/- 0.009895 (max: 0.14) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 2.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 0.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 3.412 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 6.726 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 3.412 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 3.919 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-LSTM_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.7999999523162842, 0.3999999761581421, 0.19999998807907104, 1.0, 0.5999999642372131, 1.0, 1.0, 1.0, 0.7999999523162842, 1.1999999284744263, 0.19999998807907104, 0.19999998807907104, 1.399999976158142, 0.3999999761581421, 0.7999999523162842, 1.5999999046325684, 0.19999998807907104, 2.0, 0.7999999523162842, 0.7999999523162842, 0.7999999523162842, 0.5999999642372131, 0.3999999761581421, 1.399999976158142, 1.1999999284744263, 0.7999999523162842, 1.0, 1.7999999523162842, 0.7999999523162842, 1.399999976158142, 0.5999999642372131, 0.19999998807907104, 0.7999999523162842, 0.5999999642372131, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 0.5999999642372131, 0.3999999761581421, 0.7999999523162842, 0.0, 0.3999999761581421, 1.399999976158142, 0.0, 0.5999999642372131, 0.0, 1.399999976158142, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.7999999523162842, 0.3999999761581421, 1.0, 1.0, 1.399999976158142, 0.3999999761581421, 0.7999999523162842, 0.7999999523162842, 1.1999999284744263, 0.7999999523162842, 0.5999999642372131, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 1.399999976158142, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.19999998807907104, 1.0, 0.7999999523162842, 1.1999999284744263, 0.7999999523162842, 1.5999999046325684, 0.7999999523162842, 0.5999999642372131, 1.0, 1.7999999523162842, 0.7999999523162842, 0.5999999642372131, 0.3999999761581421, 0.7999999523162842, 0.5999999642372131, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.5999999642372131, 1.0, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 2.0, 0.3999999761581421, 1.399999976158142, 0.7999999523162842, 1.399999976158142, 0.19999998807907104, 0.5999999642372131, 0.0, 0.3999999761581421, 1.399999976158142, 0.5999999642372131] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.75+/- 0.06932 (max: 2.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.7125+/- 0.1437 (max: 2.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.725+/- 0.1047 (max: 1.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.8125+/- 0.1147 (max: 1.8) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.508+/- 0.2014 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.336+/- 0.3917 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.485+/- 0.3334 (max: 5.103) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.702+/- 0.3351 (max: 5.724) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-LSTM_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [8.199999809265137, 8.199999809265137, 7.0, 5.400000095367432, 7.399999618530273, 4.599999904632568, 3.5999999046325684, 1.7999999523162842, 4.599999904632568, 4.199999809265137, 0.5999999642372131, 0.0, 0.19999998807907104, 0.19999998807907104, 2.0, 2.0, 2.0, 2.0, 4.199999809265137, 2.5999999046325684, 13.199999809265137, 9.800000190734863, 8.59999942779541, 2.799999952316284, 7.0, 5.0, 8.399999618530273, 4.599999904632568, 0.3999999761581421, 0.7999999523162842, 8.59999942779541, 6.199999809265137, 11.399999618530273, 3.799999952316284, 5.799999713897705, 2.200000047683716, 3.5999999046325684, 1.399999976158142, 6.199999809265137, 4.400000095367432, 4.0, 3.0, 8.800000190734863, 7.199999809265137, 9.399999618530273, 8.0, 10.0, 1.7999999523162842] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [8.199999809265137, 8.199999809265137, 3.5999999046325684, 1.7999999523162842, 0.19999998807907104, 0.19999998807907104, 4.199999809265137, 2.5999999046325684, 7.0, 5.0, 8.59999942779541, 6.199999809265137, 3.5999999046325684, 1.399999976158142, 8.800000190734863, 7.199999809265137] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [7.0, 5.400000095367432, 4.599999904632568, 4.199999809265137, 2.0, 2.0, 13.199999809265137, 9.800000190734863, 8.399999618530273, 4.599999904632568, 11.399999618530273, 3.799999952316284, 6.199999809265137, 4.400000095367432, 9.399999618530273, 8.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [7.399999618530273, 4.599999904632568, 0.5999999642372131, 0.0, 2.0, 2.0, 8.59999942779541, 2.799999952316284, 0.3999999761581421, 0.7999999523162842, 5.799999713897705, 2.200000047683716, 4.0, 3.0, 10.0, 1.7999999523162842] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 4.942+/- 0.4749 (max: 13.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 3.5+/- 0.7572 (max: 10.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 4.8+/- 0.7559 (max: 8.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 6.525+/- 0.8173 (max: 13.2) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 8.365+/- 0.4804 (max: 16.36) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 6.551+/- 0.7411 (max: 10.83) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 8.041+/- 0.763 (max: 12.03) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.5+/- 0.7122 (max: 16.36) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.01896+/- 0.004419 (max: 0.12) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0025+/- 0.001936 (max: 0.03) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.01125+/- 0.00407 (max: 0.06) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.04312+/- 0.01011 (max: 0.12) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 2.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 6.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-LSTM_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [9.399999618530273, 2.3999998569488525, 14.199999809265137, 10.399999618530273, 12.0, 0.3999999761581421, 11.199999809265137, 5.400000095367432, 14.199999809265137, 38.79999923706055, 12.799999237060547, 0.0, 14.59999942779541, 38.599998474121094, 14.199999809265137, 56.39999771118164, 8.199999809265137, 8.59999942779541, 13.59999942779541, 16.19999885559082, 15.59999942779541, 4.199999809265137, 13.399999618530273, 1.399999976158142, 9.800000190734863, 1.5999999046325684, 15.0, 15.199999809265137, 15.799999237060547, 2.200000047683716, 10.199999809265137, 2.200000047683716, 15.399999618530273, 23.799999237060547, 10.0, 0.19999998807907104, 9.0, 2.0, 16.600000381469727, 13.799999237060547, 11.59999942779541, 4.0, 8.59999942779541, 4.799999713897705, 17.19999885559082, 18.0, 14.399999618530273, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [9.399999618530273, 2.3999998569488525, 11.199999809265137, 5.400000095367432, 14.59999942779541, 38.599998474121094, 13.59999942779541, 16.19999885559082, 9.800000190734863, 1.5999999046325684, 10.199999809265137, 2.200000047683716, 9.0, 2.0, 8.59999942779541, 4.799999713897705] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [14.199999809265137, 10.399999618530273, 14.199999809265137, 38.79999923706055, 14.199999809265137, 56.39999771118164, 15.59999942779541, 4.199999809265137, 15.0, 15.199999809265137, 15.399999618530273, 23.799999237060547, 16.600000381469727, 13.799999237060547, 17.19999885559082, 18.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [12.0, 0.3999999761581421, 12.799999237060547, 0.0, 8.199999809265137, 8.59999942779541, 13.399999618530273, 1.399999976158142, 15.799999237060547, 2.200000047683716, 10.0, 0.19999998807907104, 11.59999942779541, 4.0, 14.399999618530273, 0.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 12.03+/- 1.516 (max: 56.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 7.187+/- 1.467 (max: 15.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 9.975+/- 2.233 (max: 38.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 18.94+/- 3.074 (max: 56.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 13.36+/- 0.8972 (max: 30.38) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 9.721+/- 1.534 (max: 17.27) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 12.74+/- 1.377 (max: 26.42) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 17.62+/- 1.095 (max: 30.38) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1352+/- 0.02321 (max: 0.76) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0625+/- 0.0167 (max: 0.19) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.1044+/- 0.03654 (max: 0.59) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2387+/- 0.04818 (max: 0.76) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 1.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 4.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 6.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 9.075 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.02 | +----------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-LSTM_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [7.399999618530273, 5.0, 70.0, 67.79999542236328, 82.5999984741211, 79.5999984741211, 12.399999618530273, 11.199999809265137, 67.5999984741211, 62.19999694824219, 73.4000015258789, 76.4000015258789, 10.399999618530273, 12.199999809265137, 74.4000015258789, 76.4000015258789, 59.79999923706055, 66.4000015258789, 12.199999809265137, 12.399999618530273, 72.19999694824219, 74.79999542236328, 58.19999694824219, 62.19999694824219, 6.199999809265137, 6.599999904632568, 60.19999694824219, 56.19999694824219, 69.4000015258789, 72.5999984741211, 13.399999618530273, 11.399999618530273, 68.79999542236328, 76.19999694824219, 84.0, 80.5999984741211, 12.399999618530273, 13.799999237060547, 70.19999694824219, 74.0, 70.5999984741211, 73.5999984741211, 14.399999618530273, 19.0, 80.5999984741211, 84.79999542236328, 74.79999542236328, 77.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [7.399999618530273, 5.0, 12.399999618530273, 11.199999809265137, 10.399999618530273, 12.199999809265137, 12.199999809265137, 12.399999618530273, 6.199999809265137, 6.599999904632568, 13.399999618530273, 11.399999618530273, 12.399999618530273, 13.799999237060547, 14.399999618530273, 19.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [70.0, 67.79999542236328, 67.5999984741211, 62.19999694824219, 74.4000015258789, 76.4000015258789, 72.19999694824219, 74.79999542236328, 60.19999694824219, 56.19999694824219, 68.79999542236328, 76.19999694824219, 70.19999694824219, 74.0, 80.5999984741211, 84.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [82.5999984741211, 79.5999984741211, 73.4000015258789, 76.4000015258789, 59.79999923706055, 66.4000015258789, 58.19999694824219, 62.19999694824219, 69.4000015258789, 72.5999984741211, 84.0, 80.5999984741211, 70.5999984741211, 73.5999984741211, 74.79999542236328, 77.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 51.62+/- 4.264 (max: 84.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 72.57+/- 1.953 (max: 84.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 11.27+/- 0.8865 (max: 19.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 71.02+/- 1.845 (max: 84.8) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 25.9+/- 1.375 (max: 37.64) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 31.05+/- 1.005 (max: 37.64) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.34+/- 0.4883 (max: 15.84) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 33.32+/- 0.7649 (max: 37.14) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6121+/- 0.05251 (max: 0.96) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.8756+/- 0.01605 (max: 0.96) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.11+/- 0.01528 (max: 0.25) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.8506+/- 0.01424 (max: 0.91) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 5.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 58.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 5.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 56.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 9.539 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 24.33 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 9.539 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 27.26 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.73 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.71 | +------------------------------------------------------------------------------------------------- + + + + + + + + + +Evaluating PLR_CNN-LSTM_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.0, 4.599999904632568, 19.0, 14.59999942779541, 17.799999237060547, 16.799999237060547, 13.799999237060547, 15.199999809265137, 31.799999237060547, 28.399999618530273, 17.399999618530273, 18.19999885559082, 12.399999618530273, 11.199999809265137, 20.799999237060547, 22.19999885559082, 19.0, 17.600000381469727, 11.0, 7.399999618530273, 20.0, 18.0, 15.399999618530273, 16.399999618530273, 17.0, 13.399999618530273, 21.0, 19.799999237060547, 17.0, 18.0, 10.0, 9.800000190734863, 20.19999885559082, 22.19999885559082, 29.599998474121094, 31.0, 3.5999999046325684, 3.3999998569488525, 13.799999237060547, 14.399999618530273, 20.600000381469727, 18.799999237060547, 5.599999904632568, 7.399999618530273, 9.0, 9.0, 11.399999618530273, 12.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.0, 4.599999904632568, 13.799999237060547, 15.199999809265137, 12.399999618530273, 11.199999809265137, 11.0, 7.399999618530273, 17.0, 13.399999618530273, 10.0, 9.800000190734863, 3.5999999046325684, 3.3999998569488525, 5.599999904632568, 7.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [19.0, 14.59999942779541, 31.799999237060547, 28.399999618530273, 20.799999237060547, 22.19999885559082, 20.0, 18.0, 21.0, 19.799999237060547, 20.19999885559082, 22.19999885559082, 13.799999237060547, 14.399999618530273, 9.0, 9.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [17.799999237060547, 16.799999237060547, 17.399999618530273, 18.19999885559082, 19.0, 17.600000381469727, 15.399999618530273, 16.399999618530273, 17.0, 18.0, 29.599998474121094, 31.0, 20.600000381469727, 18.799999237060547, 11.399999618530273, 12.59999942779541] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 15.66+/- 0.9834 (max: 31.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 18.6+/- 1.277 (max: 31.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 9.362+/- 1.092 (max: 17.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 19.01+/- 1.516 (max: 31.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.64+/- 0.3008 (max: 16.09) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.33+/- 0.532 (max: 16.09) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 10.92+/- 0.5703 (max: 14.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.66+/- 0.363 (max: 15.2) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.1352+/- 0.02203 (max: 0.63) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.1475+/- 0.04408 (max: 0.59) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.05062+/- 0.01355 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2075+/- 0.04001 (max: 0.63) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 3.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 11.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 3.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 9.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 7.513 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.66 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 7.513 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 10.34 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.01 | +------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-LSTM_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [4.799999713897705, 0.0, 7.199999809265137, 0.19999998807907104, 2.3999998569488525, 0.5999999642372131, 4.400000095367432, 0.0, 4.400000095367432, 0.19999998807907104, 3.1999998092651367, 0.0, 3.5999999046325684, 0.0, 4.599999904632568, 0.19999998807907104, 3.799999952316284, 0.19999998807907104, 5.199999809265137, 0.0, 3.799999952316284, 0.0, 2.0, 0.0, 4.599999904632568, 0.0, 5.799999713897705, 0.0, 4.0, 0.0, 3.3999998569488525, 0.0, 4.599999904632568, 0.19999998807907104, 0.5999999642372131, 0.0, 2.5999999046325684, 0.0, 2.799999952316284, 0.0, 2.5999999046325684, 0.0, 4.799999713897705, 0.0, 7.199999809265137, 0.19999998807907104, 7.599999904632568, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [4.799999713897705, 0.0, 4.400000095367432, 0.0, 3.5999999046325684, 0.0, 5.199999809265137, 0.0, 4.599999904632568, 0.0, 3.3999998569488525, 0.0, 2.5999999046325684, 0.0, 4.799999713897705, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [7.199999809265137, 0.19999998807907104, 4.400000095367432, 0.19999998807907104, 4.599999904632568, 0.19999998807907104, 3.799999952316284, 0.0, 5.799999713897705, 0.0, 4.599999904632568, 0.19999998807907104, 2.799999952316284, 0.0, 7.199999809265137, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [2.3999998569488525, 0.5999999642372131, 3.1999998092651367, 0.0, 3.799999952316284, 0.19999998807907104, 2.0, 0.0, 4.0, 0.0, 0.5999999642372131, 0.0, 2.5999999046325684, 0.0, 7.599999904632568, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 2.121+/- 0.3435 (max: 7.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.687+/- 0.5407 (max: 7.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 2.087+/- 0.5598 (max: 5.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.587+/- 0.6903 (max: 7.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.279+/- 0.5598 (max: 9.708) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.841+/- 0.9043 (max: 9.708) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.058+/- 1.056 (max: 9.217) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.939+/- 0.9837 (max: 9.6) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0008333+/- 0.0004031 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.00125+/- 0.0008539 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-LSTM_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [1.0, 1.0, 2.799999952316284, 2.0, 1.0, 0.7999999523162842, 1.1999999284744263, 1.5999999046325684, 1.5999999046325684, 1.399999976158142, 0.0, 0.0, 3.0, 2.0, 13.799999237060547, 5.199999809265137, 7.399999618530273, 2.3999998569488525, 2.200000047683716, 0.7999999523162842, 8.800000190734863, 6.399999618530273, 1.7999999523162842, 0.5999999642372131, 2.799999952316284, 1.7999999523162842, 7.0, 7.0, 5.799999713897705, 1.1999999284744263, 2.5999999046325684, 0.19999998807907104, 4.0, 2.200000047683716, 3.0, 1.5999999046325684, 1.0, 0.3999999761581421, 1.399999976158142, 3.0, 0.19999998807907104, 0.0, 1.5999999046325684, 1.5999999046325684, 2.799999952316284, 1.0, 0.3999999761581421, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [1.0, 1.0, 1.1999999284744263, 1.5999999046325684, 3.0, 2.0, 2.200000047683716, 0.7999999523162842, 2.799999952316284, 1.7999999523162842, 2.5999999046325684, 0.19999998807907104, 1.0, 0.3999999761581421, 1.5999999046325684, 1.5999999046325684] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [2.799999952316284, 2.0, 1.5999999046325684, 1.399999976158142, 13.799999237060547, 5.199999809265137, 8.800000190734863, 6.399999618530273, 7.0, 7.0, 4.0, 2.200000047683716, 1.399999976158142, 3.0, 2.799999952316284, 1.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [1.0, 0.7999999523162842, 0.0, 0.0, 7.399999618530273, 2.3999998569488525, 1.7999999523162842, 0.5999999642372131, 5.799999713897705, 1.1999999284744263, 3.0, 1.5999999046325684, 0.19999998807907104, 0.0, 0.3999999761581421, 0.19999998807907104] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 2.533+/- 0.3906 (max: 13.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.65+/- 0.5365 (max: 7.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.55+/- 0.2062 (max: 3.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 4.4+/- 0.8687 (max: 13.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.7+/- 0.3884 (max: 10.7) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 4.221+/- 0.7807 (max: 10.45) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.152+/- 0.3571 (max: 7.141) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 7.727+/- 0.4913 (max: 10.7) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.003125+/- 0.001038 (max: 0.03) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00125+/- 0.00125 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0075+/- 0.0025 (max: 0.03) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 1.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 4.359 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-LSTM_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 16.0, 0.5999999642372131, 34.39999771118164, 0.19999998807907104, 35.599998474121094, 0.0, 15.59999942779541, 0.0, 53.39999771118164, 0.19999998807907104, 47.20000076293945, 0.0, 54.79999923706055, 0.0, 75.79999542236328, 0.19999998807907104, 36.39999771118164, 0.0, 36.39999771118164, 0.0, 31.599998474121094, 0.0, 27.799999237060547, 0.0, 18.799999237060547, 0.0, 38.79999923706055, 0.19999998807907104, 26.0, 0.0, 15.199999809265137, 0.0, 46.20000076293945, 0.0, 36.79999923706055, 0.19999998807907104, 15.399999618530273, 0.19999998807907104, 44.39999771118164, 0.0, 26.399999618530273, 0.19999998807907104, 19.0, 0.0, 41.20000076293945, 0.19999998807907104, 18.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 16.0, 0.0, 15.59999942779541, 0.0, 54.79999923706055, 0.0, 36.39999771118164, 0.0, 18.799999237060547, 0.0, 15.199999809265137, 0.19999998807907104, 15.399999618530273, 0.19999998807907104, 19.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.5999999642372131, 34.39999771118164, 0.0, 53.39999771118164, 0.0, 75.79999542236328, 0.0, 31.599998474121094, 0.0, 38.79999923706055, 0.0, 46.20000076293945, 0.19999998807907104, 44.39999771118164, 0.0, 41.20000076293945] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.19999998807907104, 35.599998474121094, 0.19999998807907104, 47.20000076293945, 0.19999998807907104, 36.39999771118164, 0.0, 27.799999237060547, 0.19999998807907104, 26.0, 0.0, 36.79999923706055, 0.0, 26.399999618530273, 0.19999998807907104, 18.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 16.95+/- 2.899 (max: 75.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 15.95+/- 4.366 (max: 47.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 11.97+/- 3.935 (max: 54.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 22.91+/- 6.354 (max: 75.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 10.71+/- 1.505 (max: 28.33) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 11.3+/- 2.652 (max: 25.39) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.996+/- 2.301 (max: 26.02) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 11.83+/- 2.942 (max: 28.33) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.2496+/- 0.04398 (max: 0.93) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.2381+/- 0.06728 (max: 0.73) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.1594+/- 0.05977 (max: 0.82) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.3512+/- 0.09405 (max: 0.93) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-LSTM_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.800000190734863, 5.799999713897705, 57.19999694824219, 55.79999923706055, 77.19999694824219, 77.4000015258789, 12.399999618530273, 13.799999237060547, 65.79999542236328, 60.599998474121094, 76.4000015258789, 76.4000015258789, 18.600000381469727, 16.600000381469727, 76.5999984741211, 77.4000015258789, 74.5999984741211, 77.19999694824219, 17.0, 17.399999618530273, 72.19999694824219, 68.5999984741211, 55.79999923706055, 59.79999923706055, 6.599999904632568, 9.800000190734863, 65.4000015258789, 61.39999771118164, 71.4000015258789, 71.19999694824219, 17.19999885559082, 17.799999237060547, 66.5999984741211, 66.79999542236328, 74.19999694824219, 76.4000015258789, 17.19999885559082, 17.799999237060547, 67.0, 61.19999694824219, 75.0, 75.5999984741211, 16.19999885559082, 15.799999237060547, 69.5999984741211, 68.79999542236328, 70.5999984741211, 70.5999984741211] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.800000190734863, 5.799999713897705, 12.399999618530273, 13.799999237060547, 18.600000381469727, 16.600000381469727, 17.0, 17.399999618530273, 6.599999904632568, 9.800000190734863, 17.19999885559082, 17.799999237060547, 17.19999885559082, 17.799999237060547, 16.19999885559082, 15.799999237060547] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [57.19999694824219, 55.79999923706055, 65.79999542236328, 60.599998474121094, 76.5999984741211, 77.4000015258789, 72.19999694824219, 68.5999984741211, 65.4000015258789, 61.39999771118164, 66.5999984741211, 66.79999542236328, 67.0, 61.19999694824219, 69.5999984741211, 68.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [77.19999694824219, 77.4000015258789, 76.4000015258789, 76.4000015258789, 74.5999984741211, 77.19999694824219, 55.79999923706055, 59.79999923706055, 71.4000015258789, 71.19999694824219, 74.19999694824219, 76.4000015258789, 75.0, 75.5999984741211, 70.5999984741211, 70.5999984741211] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 51.03+/- 3.889 (max: 77.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 72.49+/- 1.564 (max: 77.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 14.3+/- 1.068 (max: 18.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 66.31+/- 1.533 (max: 77.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 18.65+/- 0.6188 (max: 24.89) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 21.73+/- 0.3164 (max: 24.39) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.19+/- 0.3923 (max: 15.75) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 21.02+/- 0.5964 (max: 24.89) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6787+/- 0.05712 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.965+/- 0.008317 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1287+/- 0.01612 (max: 0.21) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9425+/- 0.008391 (max: 0.98) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 5.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 55.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 5.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 55.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.31 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 19.49 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.31 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 16.5 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.87 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.87 | +-------------------------------------------------------------------------------------------------- + + + + + + +Evaluating PAIRED_CNN-LSTM_SEED1 against population in Overcooked-CoordRing6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.199999809265137, 3.799999952316284, 16.399999618530273, 15.399999618530273, 22.600000381469727, 21.799999237060547, 11.59999942779541, 14.59999942779541, 29.19999885559082, 29.799999237060547, 21.799999237060547, 22.0, 13.59999942779541, 13.59999942779541, 22.799999237060547, 24.399999618530273, 22.0, 20.600000381469727, 12.199999809265137, 10.59999942779541, 14.0, 16.799999237060547, 16.0, 15.799999237060547, 13.0, 15.0, 17.799999237060547, 21.0, 21.600000381469727, 21.0, 13.399999618530273, 17.0, 21.399999618530273, 23.0, 35.599998474121094, 34.0, 2.5999999046325684, 4.599999904632568, 16.600000381469727, 21.19999885559082, 26.0, 28.399999618530273, 13.199999809265137, 13.399999618530273, 12.399999618530273, 15.399999618530273, 20.0, 23.600000381469727] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.199999809265137, 3.799999952316284, 11.59999942779541, 14.59999942779541, 13.59999942779541, 13.59999942779541, 12.199999809265137, 10.59999942779541, 13.0, 15.0, 13.399999618530273, 17.0, 2.5999999046325684, 4.599999904632568, 13.199999809265137, 13.399999618530273] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [16.399999618530273, 15.399999618530273, 29.19999885559082, 29.799999237060547, 22.799999237060547, 24.399999618530273, 14.0, 16.799999237060547, 17.799999237060547, 21.0, 21.399999618530273, 23.0, 16.600000381469727, 21.19999885559082, 12.399999618530273, 15.399999618530273] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [22.600000381469727, 21.799999237060547, 21.799999237060547, 22.0, 22.0, 20.600000381469727, 16.0, 15.799999237060547, 21.600000381469727, 21.0, 35.599998474121094, 34.0, 26.0, 28.399999618530273, 20.0, 23.600000381469727] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 18.06+/- 1.038 (max: 35.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 23.3+/- 1.359 (max: 35.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.02+/- 1.137 (max: 17.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 19.85+/- 1.285 (max: 29.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.26+/- 0.332 (max: 18.41) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.6+/- 0.5354 (max: 15.36) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.78+/- 0.6246 (max: 14.25) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.39+/- 0.4797 (max: 18.41) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.191+/- 0.02333 (max: 0.71) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2656+/- 0.04847 (max: 0.71) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.08+/- 0.01393 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2275+/- 0.03593 (max: 0.52) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 15.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 12.4 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.726 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.66 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.726 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.09 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.07 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.06 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-LSTM_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [2.799999952316284, 0.0, 3.5999999046325684, 0.0, 4.0, 0.0, 5.0, 0.0, 3.799999952316284, 0.0, 2.5999999046325684, 0.0, 2.3999998569488525, 0.0, 2.200000047683716, 0.0, 1.5999999046325684, 0.0, 3.1999998092651367, 0.0, 3.3999998569488525, 0.0, 1.7999999523162842, 0.0, 2.3999998569488525, 0.0, 2.200000047683716, 0.0, 2.3999998569488525, 0.0, 3.799999952316284, 0.0, 3.3999998569488525, 0.0, 1.399999976158142, 0.0, 0.7999999523162842, 0.0, 2.200000047683716, 0.0, 1.5999999046325684, 0.19999998807907104, 5.400000095367432, 0.0, 3.5999999046325684, 0.0, 5.599999904632568, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [2.799999952316284, 0.0, 5.0, 0.0, 2.3999998569488525, 0.0, 3.1999998092651367, 0.0, 2.3999998569488525, 0.0, 3.799999952316284, 0.0, 0.7999999523162842, 0.0, 5.400000095367432, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [3.5999999046325684, 0.0, 3.799999952316284, 0.0, 2.200000047683716, 0.0, 3.3999998569488525, 0.0, 2.200000047683716, 0.0, 3.3999998569488525, 0.0, 2.200000047683716, 0.0, 3.5999999046325684, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [4.0, 0.0, 2.5999999046325684, 0.0, 1.5999999046325684, 0.0, 1.7999999523162842, 0.0, 2.3999998569488525, 0.0, 1.399999976158142, 0.0, 1.5999999046325684, 0.19999998807907104, 5.599999904632568, 0.19999998807907104] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.492+/- 0.2494 (max: 5.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.337+/- 0.4166 (max: 5.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.612+/- 0.4884 (max: 5.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.525+/- 0.4123 (max: 3.8) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.539+/- 0.5121 (max: 9.415) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.519+/- 0.8357 (max: 9.415) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.536+/- 0.9511 (max: 8.879) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.563+/- 0.9284 (max: 7.846) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0002083+/- 0.0002083 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [10.800000190734863, 7.599999904632568, 16.19999885559082, 11.800000190734863, 14.399999618530273, 13.199999809265137, 8.399999618530273, 5.0, 11.399999618530273, 9.800000190734863, 3.1999998092651367, 0.0, 3.0, 2.200000047683716, 13.199999809265137, 9.0, 11.199999809265137, 5.799999713897705, 9.800000190734863, 6.199999809265137, 18.600000381469727, 16.399999618530273, 9.399999618530273, 3.5999999046325684, 11.0, 9.0, 16.600000381469727, 14.799999237060547, 7.399999618530273, 3.0, 12.59999942779541, 9.800000190734863, 15.799999237060547, 10.199999809265137, 9.0, 5.799999713897705, 7.0, 6.0, 15.199999809265137, 12.199999809265137, 8.0, 3.1999998092651367, 10.199999809265137, 6.599999904632568, 13.799999237060547, 13.199999809265137, 10.59999942779541, 3.1999998092651367] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [10.800000190734863, 7.599999904632568, 8.399999618530273, 5.0, 3.0, 2.200000047683716, 9.800000190734863, 6.199999809265137, 11.0, 9.0, 12.59999942779541, 9.800000190734863, 7.0, 6.0, 10.199999809265137, 6.599999904632568] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [16.19999885559082, 11.800000190734863, 11.399999618530273, 9.800000190734863, 13.199999809265137, 9.0, 18.600000381469727, 16.399999618530273, 16.600000381469727, 14.799999237060547, 15.799999237060547, 10.199999809265137, 15.199999809265137, 12.199999809265137, 13.799999237060547, 13.199999809265137] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [14.399999618530273, 13.199999809265137, 3.1999998092651367, 0.0, 11.199999809265137, 5.799999713897705, 9.399999618530273, 3.5999999046325684, 7.399999618530273, 3.0, 9.0, 5.799999713897705, 8.0, 3.1999998092651367, 10.59999942779541, 3.1999998092651367] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 9.467+/- 0.64 (max: 18.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 6.937+/- 1.036 (max: 14.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 7.825+/- 0.7303 (max: 12.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 13.64+/- 0.6922 (max: 18.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 10.98+/- 0.4554 (max: 17.82) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 8.792+/- 0.6941 (max: 12.35) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.29+/- 0.4721 (max: 14.04) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.85+/- 0.586 (max: 17.82) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.05271+/- 0.009419 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0125+/- 0.006021 (max: 0.09) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02625+/- 0.007296 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1194+/- 0.01714 (max: 0.22) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 2.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 9.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 6.258 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 9.474 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.19999998807907104, 6.799999713897705, 0.3999999761581421, 13.399999618530273, 0.5999999642372131, 6.799999713897705, 0.3999999761581421, 9.399999618530273, 0.3999999761581421, 42.39999771118164, 0.19999998807907104, 6.199999809265137, 0.3999999761581421, 32.599998474121094, 0.3999999761581421, 64.5999984741211, 0.5999999642372131, 14.59999942779541, 0.19999998807907104, 23.799999237060547, 1.0, 14.799999237060547, 0.7999999523162842, 2.799999952316284, 0.5999999642372131, 5.799999713897705, 0.5999999642372131, 22.0, 0.19999998807907104, 7.799999713897705, 0.5999999642372131, 4.799999713897705, 0.5999999642372131, 30.399999618530273, 0.3999999761581421, 5.599999904632568, 0.19999998807907104, 8.59999942779541, 0.3999999761581421, 19.19999885559082, 0.19999998807907104, 9.59999942779541, 0.5999999642372131, 9.0, 0.19999998807907104, 25.0, 1.5999999046325684, 7.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.19999998807907104, 6.799999713897705, 0.3999999761581421, 9.399999618530273, 0.3999999761581421, 32.599998474121094, 0.19999998807907104, 23.799999237060547, 0.5999999642372131, 5.799999713897705, 0.5999999642372131, 4.799999713897705, 0.19999998807907104, 8.59999942779541, 0.5999999642372131, 9.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.3999999761581421, 13.399999618530273, 0.3999999761581421, 42.39999771118164, 0.3999999761581421, 64.5999984741211, 1.0, 14.799999237060547, 0.5999999642372131, 22.0, 0.5999999642372131, 30.399999618530273, 0.3999999761581421, 19.19999885559082, 0.19999998807907104, 25.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.5999999642372131, 6.799999713897705, 0.19999998807907104, 6.199999809265137, 0.5999999642372131, 14.59999942779541, 0.7999999523162842, 2.799999952316284, 0.19999998807907104, 7.799999713897705, 0.3999999761581421, 5.599999904632568, 0.19999998807907104, 9.59999942779541, 1.5999999046325684, 7.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.433+/- 1.869 (max: 64.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 4.062+/- 1.079 (max: 14.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.5+/- 2.328 (max: 32.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 14.74+/- 4.699 (max: 64.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 9.568+/- 1.171 (max: 32.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.418+/- 1.205 (max: 15.96) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.482+/- 1.798 (max: 26.63) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.8+/- 2.67 (max: 32.6) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.09896+/- 0.02721 (max: 0.83) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.02562+/- 0.01004 (max: 0.15) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0675+/- 0.03574 (max: 0.51) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2037+/- 0.06651 (max: 0.83) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.800000190734863, 10.800000190734863, 65.79999542236328, 59.0, 84.0, 88.0, 13.199999809265137, 16.799999237060547, 70.5999984741211, 69.0, 88.5999984741211, 101.0, 18.0, 17.19999885559082, 88.79999542236328, 85.19999694824219, 90.4000015258789, 94.4000015258789, 18.0, 16.600000381469727, 85.4000015258789, 81.0, 74.0, 80.0, 10.800000190734863, 9.199999809265137, 68.0, 64.79999542236328, 88.4000015258789, 82.79999542236328, 14.799999237060547, 18.600000381469727, 82.4000015258789, 81.0, 88.0, 93.19999694824219, 16.0, 15.399999618530273, 78.19999694824219, 80.0, 91.19999694824219, 88.79999542236328, 20.600000381469727, 19.0, 84.4000015258789, 80.79999542236328, 89.19999694824219, 88.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.800000190734863, 10.800000190734863, 13.199999809265137, 16.799999237060547, 18.0, 17.19999885559082, 18.0, 16.600000381469727, 10.800000190734863, 9.199999809265137, 14.799999237060547, 18.600000381469727, 16.0, 15.399999618530273, 20.600000381469727, 19.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [65.79999542236328, 59.0, 70.5999984741211, 69.0, 88.79999542236328, 85.19999694824219, 85.4000015258789, 81.0, 68.0, 64.79999542236328, 82.4000015258789, 81.0, 78.19999694824219, 80.0, 84.4000015258789, 80.79999542236328] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [84.0, 88.0, 88.5999984741211, 101.0, 90.4000015258789, 94.4000015258789, 74.0, 80.0, 88.4000015258789, 82.79999542236328, 88.0, 93.19999694824219, 91.19999694824219, 88.79999542236328, 89.19999694824219, 88.19999694824219] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 59.97+/- 4.757 (max: 101.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 88.14+/- 1.525 (max: 101.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.24+/- 0.9125 (max: 20.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 76.52+/- 2.235 (max: 88.8) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 22.72+/- 0.9762 (max: 32.03) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.95+/- 0.5288 (max: 32.03) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.67+/- 0.3485 (max: 15.75) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.55+/- 0.5405 (max: 30.32) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6831+/- 0.05542 (max: 0.99) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9687+/- 0.004644 (max: 0.99) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1506+/- 0.01847 (max: 0.26) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.93+/- 0.01021 (max: 0.98) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 8.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 74.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 8.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 59.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.11 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 23.83 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 11.11 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 22.86 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.93 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.86 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------- + + + + + + + + + + + + +Evaluating PAIRED_CNN-LSTM_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [3.0, 4.400000095367432, 20.799999237060547, 16.399999618530273, 21.600000381469727, 23.19999885559082, 16.600000381469727, 16.600000381469727, 31.799999237060547, 32.79999923706055, 22.399999618530273, 21.399999618530273, 12.799999237060547, 13.399999618530273, 22.19999885559082, 26.0, 21.0, 20.0, 9.800000190734863, 9.199999809265137, 18.600000381469727, 18.399999618530273, 17.799999237060547, 15.199999809265137, 15.59999942779541, 18.600000381469727, 20.600000381469727, 21.600000381469727, 23.399999618530273, 23.0, 11.199999809265137, 11.199999809265137, 24.19999885559082, 22.0, 39.39999771118164, 40.599998474121094, 2.200000047683716, 3.799999952316284, 17.19999885559082, 18.799999237060547, 25.399999618530273, 25.399999618530273, 11.199999809265137, 11.0, 15.0, 12.399999618530273, 20.19999885559082, 22.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [3.0, 4.400000095367432, 16.600000381469727, 16.600000381469727, 12.799999237060547, 13.399999618530273, 9.800000190734863, 9.199999809265137, 15.59999942779541, 18.600000381469727, 11.199999809265137, 11.199999809265137, 2.200000047683716, 3.799999952316284, 11.199999809265137, 11.0] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [20.799999237060547, 16.399999618530273, 31.799999237060547, 32.79999923706055, 22.19999885559082, 26.0, 18.600000381469727, 18.399999618530273, 20.600000381469727, 21.600000381469727, 24.19999885559082, 22.0, 17.19999885559082, 18.799999237060547, 15.0, 12.399999618530273] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [21.600000381469727, 23.19999885559082, 22.399999618530273, 21.399999618530273, 21.0, 20.0, 17.799999237060547, 15.199999809265137, 23.399999618530273, 23.0, 39.39999771118164, 40.599998474121094, 25.399999618530273, 25.399999618530273, 20.19999885559082, 22.19999885559082] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 18.57+/- 1.172 (max: 40.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 23.89+/- 1.698 (max: 40.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 10.66+/- 1.276 (max: 18.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 21.17+/- 1.379 (max: 32.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.16+/- 0.3026 (max: 16.07) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.21+/- 0.2927 (max: 13.52) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.88+/- 0.7115 (max: 16.01) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.4+/- 0.3164 (max: 16.07) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.2098+/- 0.02699 (max: 0.9) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2856+/- 0.05881 (max: 0.9) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.08687+/- 0.01724 (max: 0.23) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2569+/- 0.03939 (max: 0.6) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 15.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 12.4 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.258 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.404 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.258 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.93 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.1 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.06 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [1.5999999046325684, 0.0, 2.799999952316284, 0.0, 0.5999999642372131, 0.0, 1.7999999523162842, 0.0, 2.0, 0.0, 1.0, 0.0, 1.1999999284744263, 0.0, 2.200000047683716, 0.0, 0.5999999642372131, 0.0, 2.200000047683716, 0.0, 1.0, 0.0, 1.1999999284744263, 0.0, 1.5999999046325684, 0.0, 2.3999998569488525, 0.0, 0.7999999523162842, 0.0, 1.7999999523162842, 0.0, 1.0, 0.0, 0.19999998807907104, 0.0, 1.1999999284744263, 0.0, 1.0, 0.0, 1.0, 0.0, 2.5999999046325684, 0.0, 2.5999999046325684, 0.0, 2.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [1.5999999046325684, 0.0, 1.7999999523162842, 0.0, 1.1999999284744263, 0.0, 2.200000047683716, 0.0, 1.5999999046325684, 0.0, 1.7999999523162842, 0.0, 1.1999999284744263, 0.0, 2.5999999046325684, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [2.799999952316284, 0.0, 2.0, 0.0, 2.200000047683716, 0.0, 1.0, 0.0, 2.3999998569488525, 0.0, 1.0, 0.0, 1.0, 0.0, 2.5999999046325684, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.5999999642372131, 0.0, 1.0, 0.0, 0.5999999642372131, 0.0, 1.1999999284744263, 0.0, 0.7999999523162842, 0.0, 0.19999998807907104, 0.0, 1.0, 0.0, 2.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.7583+/- 0.1323 (max: 2.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.4625+/- 0.1502 (max: 2.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.875+/- 0.2401 (max: 2.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.9375+/- 0.2749 (max: 2.8) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 2.552+/- 0.3929 (max: 6.94) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 2.013+/- 0.5566 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 2.799+/- 0.7319 (max: 6.726) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 2.844+/- 0.7594 (max: 6.94) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [5.199999809265137, 2.3999998569488525, 7.0, 5.799999713897705, 4.799999713897705, 5.400000095367432, 5.0, 2.0, 6.599999904632568, 3.5999999046325684, 0.7999999523162842, 0.19999998807907104, 4.799999713897705, 1.0, 14.199999809265137, 4.0, 12.59999942779541, 5.0, 4.400000095367432, 2.0, 14.199999809265137, 11.0, 5.400000095367432, 1.5999999046325684, 6.399999618530273, 3.799999952316284, 17.0, 8.800000190734863, 9.800000190734863, 1.1999999284744263, 6.0, 3.3999998569488525, 10.0, 5.0, 4.199999809265137, 3.1999998092651367, 4.199999809265137, 3.1999998092651367, 6.599999904632568, 6.599999904632568, 2.200000047683716, 0.7999999523162842, 5.599999904632568, 1.1999999284744263, 10.0, 7.799999713897705, 4.599999904632568, 0.7999999523162842] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [5.199999809265137, 2.3999998569488525, 5.0, 2.0, 4.799999713897705, 1.0, 4.400000095367432, 2.0, 6.399999618530273, 3.799999952316284, 6.0, 3.3999998569488525, 4.199999809265137, 3.1999998092651367, 5.599999904632568, 1.1999999284744263] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [7.0, 5.799999713897705, 6.599999904632568, 3.5999999046325684, 14.199999809265137, 4.0, 14.199999809265137, 11.0, 17.0, 8.800000190734863, 10.0, 5.0, 6.599999904632568, 6.599999904632568, 10.0, 7.799999713897705] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [4.799999713897705, 5.400000095367432, 0.7999999523162842, 0.19999998807907104, 12.59999942779541, 5.0, 5.400000095367432, 1.5999999046325684, 9.800000190734863, 1.1999999284744263, 4.199999809265137, 3.1999998092651367, 2.200000047683716, 0.7999999523162842, 4.599999904632568, 0.7999999523162842] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 5.446+/- 0.5527 (max: 17.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 3.912+/- 0.8565 (max: 12.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.787+/- 0.4248 (max: 6.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.637+/- 0.9685 (max: 17.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 8.49+/- 0.3948 (max: 15.07) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 7.001+/- 0.6796 (max: 11.19) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.539+/- 0.3968 (max: 9.33) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.93+/- 0.4909 (max: 15.07) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.01646+/- 0.004501 (max: 0.16) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.005+/- 0.002582 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.04375+/- 0.0104 (max: 0.16) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 3.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 4.359 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 7.684 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-LSTM_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 7.599999904632568, 0.19999998807907104, 17.600000381469727, 0.0, 7.799999713897705, 0.0, 9.800000190734863, 0.0, 40.79999923706055, 0.0, 3.799999952316284, 0.0, 37.79999923706055, 0.7999999523162842, 60.599998474121094, 0.0, 15.399999618530273, 0.0, 20.399999618530273, 0.0, 14.59999942779541, 0.19999998807907104, 11.399999618530273, 0.0, 7.199999809265137, 0.0, 29.599998474121094, 0.0, 10.59999942779541, 0.0, 7.199999809265137, 0.0, 34.599998474121094, 0.0, 5.799999713897705, 0.0, 8.399999618530273, 0.0, 21.399999618530273, 0.19999998807907104, 12.199999809265137, 0.0, 10.59999942779541, 0.19999998807907104, 28.19999885559082, 0.19999998807907104, 6.199999809265137] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 7.599999904632568, 0.0, 9.800000190734863, 0.0, 37.79999923706055, 0.0, 20.399999618530273, 0.0, 7.199999809265137, 0.0, 7.199999809265137, 0.0, 8.399999618530273, 0.0, 10.59999942779541] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.19999998807907104, 17.600000381469727, 0.0, 40.79999923706055, 0.7999999523162842, 60.599998474121094, 0.0, 14.59999942779541, 0.0, 29.599998474121094, 0.0, 34.599998474121094, 0.0, 21.399999618530273, 0.19999998807907104, 28.19999885559082] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 7.799999713897705, 0.0, 3.799999952316284, 0.0, 15.399999618530273, 0.19999998807907104, 11.399999618530273, 0.0, 10.59999942779541, 0.0, 5.799999713897705, 0.19999998807907104, 12.199999809265137, 0.19999998807907104, 6.199999809265137] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.987+/- 1.926 (max: 60.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 4.612+/- 1.347 (max: 15.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.812+/- 2.535 (max: 37.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 15.54+/- 4.71 (max: 60.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.296+/- 1.283 (max: 33.64) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 6.591+/- 1.588 (max: 16.94) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 6.8+/- 1.918 (max: 23.13) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 11.5+/- 2.874 (max: 33.64) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1081+/- 0.02877 (max: 0.77) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.03437+/- 0.01281 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.06562+/- 0.03807 (max: 0.59) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2244+/- 0.06917 (max: 0.77) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.1+/- 0.01459 (max: 0.2) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.995+/- 0.1451 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [11.0, 8.0, 65.19999694824219, 66.79999542236328, 87.79999542236328, 95.4000015258789, 13.799999237060547, 14.199999809265137, 72.5999984741211, 69.4000015258789, 101.5999984741211, 103.0, 16.799999237060547, 19.600000381469727, 89.79999542236328, 95.19999694824219, 98.79999542236328, 108.0, 17.0, 17.19999885559082, 86.0, 82.5999984741211, 77.79999542236328, 79.0, 8.59999942779541, 10.800000190734863, 70.79999542236328, 70.19999694824219, 90.5999984741211, 97.0, 21.0, 20.0, 87.19999694824219, 88.19999694824219, 103.5999984741211, 102.0, 17.0, 17.0, 77.5999984741211, 82.4000015258789, 105.5999984741211, 103.79999542236328, 20.600000381469727, 17.799999237060547, 89.4000015258789, 83.0, 96.19999694824219, 98.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [11.0, 8.0, 13.799999237060547, 14.199999809265137, 16.799999237060547, 19.600000381469727, 17.0, 17.19999885559082, 8.59999942779541, 10.800000190734863, 21.0, 20.0, 17.0, 17.0, 20.600000381469727, 17.799999237060547] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [65.19999694824219, 66.79999542236328, 72.5999984741211, 69.4000015258789, 89.79999542236328, 95.19999694824219, 86.0, 82.5999984741211, 70.79999542236328, 70.19999694824219, 87.19999694824219, 88.19999694824219, 77.5999984741211, 82.4000015258789, 89.4000015258789, 83.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [87.79999542236328, 95.4000015258789, 101.5999984741211, 103.0, 98.79999542236328, 108.0, 77.79999542236328, 79.0, 90.5999984741211, 97.0, 103.5999984741211, 102.0, 105.5999984741211, 103.79999542236328, 96.19999694824219, 98.79999542236328] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 64.08+/- 5.217 (max: 108.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 96.81+/- 2.231 (max: 108.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.65+/- 1.042 (max: 21.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 79.77+/- 2.361 (max: 95.2) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 22.7+/- 0.9937 (max: 32.85) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.32+/- 0.8808 (max: 32.85) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.98+/- 0.4449 (max: 17.52) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.82+/- 0.8264 (max: 31.29) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6883+/- 0.05545 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9731+/- 0.004892 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1556+/- 0.01958 (max: 0.29) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9362+/- 0.00875 (max: 0.98) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.1+/- 0.01459 (max: 0.2) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.995+/- 0.1451 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 8.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 77.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 8.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 65.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.05 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 21.24 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 11.05 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 20.25 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.92 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.88 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- + + + + + + + +Evaluating PAIRED_CNN-LSTM_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [5.199999809265137, 6.399999618530273, 23.19999885559082, 21.600000381469727, 24.0, 25.599998474121094, 16.399999618530273, 15.199999809265137, 33.39999771118164, 34.79999923706055, 27.799999237060547, 28.399999618530273, 15.0, 15.0, 28.19999885559082, 28.0, 27.0, 25.799999237060547, 16.399999618530273, 11.199999809265137, 19.600000381469727, 25.19999885559082, 24.799999237060547, 23.799999237060547, 19.0, 18.799999237060547, 26.399999618530273, 22.600000381469727, 29.0, 28.799999237060547, 14.399999618530273, 13.399999618530273, 28.399999618530273, 28.399999618530273, 34.0, 34.39999771118164, 3.0, 2.799999952316284, 19.0, 16.799999237060547, 29.0, 29.0, 10.800000190734863, 10.199999809265137, 14.399999618530273, 13.199999809265137, 17.799999237060547, 16.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [5.199999809265137, 6.399999618530273, 16.399999618530273, 15.199999809265137, 15.0, 15.0, 16.399999618530273, 11.199999809265137, 19.0, 18.799999237060547, 14.399999618530273, 13.399999618530273, 3.0, 2.799999952316284, 10.800000190734863, 10.199999809265137] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [23.19999885559082, 21.600000381469727, 33.39999771118164, 34.79999923706055, 28.19999885559082, 28.0, 19.600000381469727, 25.19999885559082, 26.399999618530273, 22.600000381469727, 28.399999618530273, 28.399999618530273, 19.0, 16.799999237060547, 14.399999618530273, 13.199999809265137] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [24.0, 25.599998474121094, 27.799999237060547, 28.399999618530273, 27.0, 25.799999237060547, 24.799999237060547, 23.799999237060547, 29.0, 28.799999237060547, 34.0, 34.39999771118164, 29.0, 29.0, 17.799999237060547, 16.399999618530273] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 20.87+/- 1.205 (max: 34.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 26.6+/- 1.196 (max: 34.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 12.07+/- 1.32 (max: 19.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 23.95+/- 1.578 (max: 34.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 13.13+/- 0.354 (max: 19.7) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 13.42+/- 0.6714 (max: 19.7) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.13+/- 0.6485 (max: 16.34) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.84+/- 0.4485 (max: 17.93) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.2823+/- 0.02862 (max: 0.72) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.4081+/- 0.04244 (max: 0.72) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0975+/- 0.01856 (max: 0.21) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.3412+/- 0.0446 (max: 0.68) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 16.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 13.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.16 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.39 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.06 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.05 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-LSTM_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [2.5999999046325684, 0.0, 1.5999999046325684, 0.19999998807907104, 1.7999999523162842, 0.19999998807907104, 1.7999999523162842, 0.0, 1.5999999046325684, 0.3999999761581421, 0.5999999642372131, 0.0, 2.200000047683716, 0.0, 2.0, 0.0, 0.3999999761581421, 0.0, 3.0, 0.0, 1.7999999523162842, 0.19999998807907104, 0.3999999761581421, 0.0, 3.0, 0.3999999761581421, 4.599999904632568, 0.0, 4.199999809265137, 0.19999998807907104, 2.200000047683716, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 1.0, 0.0, 1.1999999284744263, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 2.200000047683716, 0.0, 3.3999998569488525, 0.0, 5.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [2.5999999046325684, 0.0, 1.7999999523162842, 0.0, 2.200000047683716, 0.0, 3.0, 0.0, 3.0, 0.3999999761581421, 2.200000047683716, 0.19999998807907104, 1.0, 0.0, 2.200000047683716, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [1.5999999046325684, 0.19999998807907104, 1.5999999046325684, 0.3999999761581421, 2.0, 0.0, 1.7999999523162842, 0.19999998807907104, 4.599999904632568, 0.0, 0.5999999642372131, 0.19999998807907104, 1.1999999284744263, 0.19999998807907104, 3.3999998569488525, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.7999999523162842, 0.19999998807907104, 0.5999999642372131, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 4.199999809265137, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 5.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.071+/- 0.1928 (max: 5.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.925+/- 0.3781 (max: 5.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.162+/- 0.3034 (max: 3.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.125+/- 0.3351 (max: 4.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.34+/- 0.4048 (max: 9.539) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.014+/- 0.7032 (max: 9.539) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.416+/- 0.7681 (max: 7.141) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.589+/- 0.6667 (max: 8.417) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0004167+/- 0.0004167 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.00125+/- 0.00125 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [6.199999809265137, 4.400000095367432, 8.800000190734863, 4.599999904632568, 5.0, 5.0, 4.599999904632568, 3.1999998092651367, 2.5999999046325684, 3.1999998092651367, 1.5999999046325684, 1.0, 8.0, 4.0, 13.799999237060547, 10.800000190734863, 16.0, 8.199999809265137, 5.799999713897705, 4.199999809265137, 17.600000381469727, 12.0, 8.0, 5.599999904632568, 4.799999713897705, 7.399999618530273, 17.600000381469727, 12.59999942779541, 10.0, 4.400000095367432, 10.0, 6.799999713897705, 10.399999618530273, 7.0, 6.599999904632568, 5.199999809265137, 5.199999809265137, 4.599999904632568, 6.0, 6.199999809265137, 3.799999952316284, 3.0, 6.799999713897705, 5.400000095367432, 8.399999618530273, 7.199999809265137, 6.0, 2.799999952316284] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [6.199999809265137, 4.400000095367432, 4.599999904632568, 3.1999998092651367, 8.0, 4.0, 5.799999713897705, 4.199999809265137, 4.799999713897705, 7.399999618530273, 10.0, 6.799999713897705, 5.199999809265137, 4.599999904632568, 6.799999713897705, 5.400000095367432] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [8.800000190734863, 4.599999904632568, 2.5999999046325684, 3.1999998092651367, 13.799999237060547, 10.800000190734863, 17.600000381469727, 12.0, 17.600000381469727, 12.59999942779541, 10.399999618530273, 7.0, 6.0, 6.199999809265137, 8.399999618530273, 7.199999809265137] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [5.0, 5.0, 1.5999999046325684, 1.0, 16.0, 8.199999809265137, 8.0, 5.599999904632568, 10.0, 4.400000095367432, 6.599999904632568, 5.199999809265137, 3.799999952316284, 3.0, 6.0, 2.799999952316284] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 6.925+/- 0.5549 (max: 17.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 5.762+/- 0.9085 (max: 16.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 5.712+/- 0.4378 (max: 10.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 9.3+/- 1.146 (max: 17.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 9.858+/- 0.3225 (max: 15.28) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 8.5+/- 0.4769 (max: 12.33) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.462+/- 0.2841 (max: 11.83) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 11.61+/- 0.5752 (max: 15.28) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.02917+/- 0.006549 (max: 0.19) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0075+/- 0.006862 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.015+/- 0.003979 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.065+/- 0.01449 (max: 0.19) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 2.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 4.359 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 4.359 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.859 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 7.297 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.3999999761581421, 4.400000095367432, 1.399999976158142, 21.0, 1.0, 5.799999713897705, 0.19999998807907104, 8.399999618530273, 0.0, 40.599998474121094, 1.1999999284744263, 14.59999942779541, 0.19999998807907104, 42.39999771118164, 0.3999999761581421, 61.599998474121094, 1.0, 23.19999885559082, 0.5999999642372131, 25.599998474121094, 0.3999999761581421, 18.399999618530273, 0.3999999761581421, 16.600000381469727, 0.3999999761581421, 5.799999713897705, 0.19999998807907104, 29.599998474121094, 0.3999999761581421, 14.799999237060547, 0.3999999761581421, 6.799999713897705, 0.3999999761581421, 30.799999237060547, 0.7999999523162842, 12.199999809265137, 0.19999998807907104, 6.399999618530273, 0.5999999642372131, 28.599998474121094, 1.5999999046325684, 16.399999618530273, 1.0, 7.0, 0.3999999761581421, 27.0, 1.399999976158142, 4.400000095367432] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.3999999761581421, 4.400000095367432, 0.19999998807907104, 8.399999618530273, 0.19999998807907104, 42.39999771118164, 0.5999999642372131, 25.599998474121094, 0.3999999761581421, 5.799999713897705, 0.3999999761581421, 6.799999713897705, 0.19999998807907104, 6.399999618530273, 1.0, 7.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [1.399999976158142, 21.0, 0.0, 40.599998474121094, 0.3999999761581421, 61.599998474121094, 0.3999999761581421, 18.399999618530273, 0.19999998807907104, 29.599998474121094, 0.3999999761581421, 30.799999237060547, 0.5999999642372131, 28.599998474121094, 0.3999999761581421, 27.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.0, 5.799999713897705, 1.1999999284744263, 14.59999942779541, 1.0, 23.19999885559082, 0.3999999761581421, 16.600000381469727, 0.3999999761581421, 14.799999237060547, 0.7999999523162842, 12.199999809265137, 1.5999999046325684, 16.399999618530273, 1.399999976158142, 4.400000095367432] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 10.15+/- 2.007 (max: 61.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 7.237+/- 1.923 (max: 23.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.887+/- 2.855 (max: 42.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 16.34+/- 4.711 (max: 61.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 10.47+/- 1.246 (max: 33.55) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 9.743+/- 1.548 (max: 19.94) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.38+/- 1.786 (max: 27.02) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 13.29+/- 2.856 (max: 33.55) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.129+/- 0.02943 (max: 0.81) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.07937+/- 0.02664 (max: 0.33) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0775+/- 0.04376 (max: 0.65) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.23+/- 0.06705 (max: 0.81) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 2.8 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-LSTM_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.800000190734863, 9.800000190734863, 64.4000015258789, 60.79999923706055, 89.4000015258789, 100.19999694824219, 16.19999885559082, 15.799999237060547, 69.4000015258789, 63.79999923706055, 96.0, 95.0, 16.399999618530273, 17.799999237060547, 77.0, 79.19999694824219, 90.5999984741211, 98.0, 17.600000381469727, 16.0, 81.5999984741211, 75.79999542236328, 66.79999542236328, 70.0, 11.800000190734863, 10.0, 67.4000015258789, 65.4000015258789, 84.4000015258789, 87.4000015258789, 17.19999885559082, 17.399999618530273, 81.0, 85.4000015258789, 92.0, 96.19999694824219, 18.19999885559082, 15.0, 70.5999984741211, 73.19999694824219, 93.4000015258789, 92.5999984741211, 18.19999885559082, 17.799999237060547, 74.19999694824219, 77.5999984741211, 96.5999984741211, 99.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.800000190734863, 9.800000190734863, 16.19999885559082, 15.799999237060547, 16.399999618530273, 17.799999237060547, 17.600000381469727, 16.0, 11.800000190734863, 10.0, 17.19999885559082, 17.399999618530273, 18.19999885559082, 15.0, 18.19999885559082, 17.799999237060547] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [64.4000015258789, 60.79999923706055, 69.4000015258789, 63.79999923706055, 77.0, 79.19999694824219, 81.5999984741211, 75.79999542236328, 67.4000015258789, 65.4000015258789, 81.0, 85.4000015258789, 70.5999984741211, 73.19999694824219, 74.19999694824219, 77.5999984741211] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [89.4000015258789, 100.19999694824219, 96.0, 95.0, 90.5999984741211, 98.0, 66.79999542236328, 70.0, 84.4000015258789, 87.4000015258789, 92.0, 96.19999694824219, 93.4000015258789, 92.5999984741211, 96.5999984741211, 99.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 59.55+/- 4.796 (max: 100.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 90.47+/- 2.405 (max: 100.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.25+/- 0.8121 (max: 18.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 72.92+/- 1.816 (max: 85.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 22.03+/- 0.9573 (max: 32.77) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 25.6+/- 0.6762 (max: 32.77) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.54+/- 0.3808 (max: 16.47) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.96+/- 0.8605 (max: 32.58) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.674+/- 0.05552 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9669+/- 0.01106 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1412+/- 0.01643 (max: 0.24) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9137+/- 0.01036 (max: 0.98) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 8.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 66.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 8.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 60.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.7 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 20.73 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.7 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 22.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.83 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.82 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- + + + + + + + +Evaluating ACCEL_CNN-LSTM_SEED1 against population in Overcooked-CoordRing6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [2.0, 2.0, 7.199999809265137, 5.599999904632568, 4.400000095367432, 5.0, 7.399999618530273, 7.599999904632568, 11.0, 12.399999618530273, 6.0, 7.399999618530273, 6.599999904632568, 8.0, 11.59999942779541, 9.399999618530273, 6.599999904632568, 8.59999942779541, 6.199999809265137, 6.799999713897705, 7.0, 10.0, 6.199999809265137, 8.800000190734863, 7.799999713897705, 8.800000190734863, 3.5999999046325684, 3.799999952316284, 6.199999809265137, 6.399999618530273, 3.3999998569488525, 5.400000095367432, 10.0, 11.59999942779541, 16.799999237060547, 15.799999237060547, 0.5999999642372131, 0.3999999761581421, 4.400000095367432, 7.0, 16.600000381469727, 21.600000381469727, 9.199999809265137, 11.399999618530273, 4.799999713897705, 6.199999809265137, 12.0, 17.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [2.0, 2.0, 7.399999618530273, 7.599999904632568, 6.599999904632568, 8.0, 6.199999809265137, 6.799999713897705, 7.799999713897705, 8.800000190734863, 3.3999998569488525, 5.400000095367432, 0.5999999642372131, 0.3999999761581421, 9.199999809265137, 11.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [7.199999809265137, 5.599999904632568, 11.0, 12.399999618530273, 11.59999942779541, 9.399999618530273, 7.0, 10.0, 3.5999999046325684, 3.799999952316284, 10.0, 11.59999942779541, 4.400000095367432, 7.0, 4.799999713897705, 6.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [4.400000095367432, 5.0, 6.0, 7.399999618530273, 6.599999904632568, 8.59999942779541, 6.199999809265137, 8.800000190734863, 6.199999809265137, 6.399999618530273, 16.799999237060547, 15.799999237060547, 16.600000381469727, 21.600000381469727, 12.0, 17.399999618530273] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 8.021+/- 0.6348 (max: 21.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 10.36+/- 1.37 (max: 21.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 5.85+/- 0.8148 (max: 11.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 7.85+/- 0.752 (max: 12.4) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 9.919+/- 0.3122 (max: 13.04) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.37+/- 0.3352 (max: 12.74) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 8.99+/- 0.7529 (max: 13.04) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 10.4+/- 0.3882 (max: 13.02) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.03146+/- 0.006267 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.04562+/- 0.01633 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.02062+/- 0.006421 (max: 0.09) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.02812+/- 0.006273 (max: 0.09) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 4.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 3.6 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 2.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.66 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 2.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 7.684 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [1.0, 0.0, 1.7999999523162842, 0.3999999761581421, 1.399999976158142, 0.3999999761581421, 2.0, 0.3999999761581421, 2.5999999046325684, 0.7999999523162842, 2.5999999046325684, 0.5999999642372131, 1.399999976158142, 0.5999999642372131, 2.0, 0.3999999761581421, 1.0, 0.5999999642372131, 2.3999998569488525, 0.7999999523162842, 2.3999998569488525, 0.7999999523162842, 1.0, 0.7999999523162842, 1.5999999046325684, 0.5999999642372131, 3.0, 0.7999999523162842, 3.1999998092651367, 0.5999999642372131, 2.3999998569488525, 0.3999999761581421, 1.7999999523162842, 0.3999999761581421, 1.0, 0.7999999523162842, 0.7999999523162842, 0.0, 1.0, 0.3999999761581421, 0.19999998807907104, 0.5999999642372131, 2.5999999046325684, 0.5999999642372131, 2.3999998569488525, 0.19999998807907104, 3.3999998569488525, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [1.0, 0.0, 2.0, 0.3999999761581421, 1.399999976158142, 0.5999999642372131, 2.3999998569488525, 0.7999999523162842, 1.5999999046325684, 0.5999999642372131, 2.3999998569488525, 0.3999999761581421, 0.7999999523162842, 0.0, 2.5999999046325684, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [1.7999999523162842, 0.3999999761581421, 2.5999999046325684, 0.7999999523162842, 2.0, 0.3999999761581421, 2.3999998569488525, 0.7999999523162842, 3.0, 0.7999999523162842, 1.7999999523162842, 0.3999999761581421, 1.0, 0.3999999761581421, 2.3999998569488525, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.399999976158142, 0.3999999761581421, 2.5999999046325684, 0.5999999642372131, 1.0, 0.5999999642372131, 1.0, 0.7999999523162842, 3.1999998092651367, 0.5999999642372131, 1.0, 0.7999999523162842, 0.19999998807907104, 0.5999999642372131, 3.3999998569488525, 0.3999999761581421] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.196+/- 0.1328 (max: 3.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.162+/- 0.2498 (max: 3.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.1+/- 0.2153 (max: 2.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.325+/- 0.2351 (max: 3.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.307+/- 0.2572 (max: 7.513) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 4.302+/- 0.4049 (max: 7.513) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.018+/- 0.5133 (max: 6.726) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.601+/- 0.4277 (max: 7.141) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-LSTM_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [14.59999942779541, 9.0, 31.399999618530273, 16.600000381469727, 19.799999237060547, 17.799999237060547, 8.199999809265137, 5.799999713897705, 13.399999618530273, 9.399999618530273, 1.1999999284744263, 0.0, 0.0, 0.19999998807907104, 1.399999976158142, 0.3999999761581421, 1.5999999046325684, 1.399999976158142, 8.399999618530273, 5.599999904632568, 22.399999618530273, 13.799999237060547, 7.199999809265137, 2.3999998569488525, 12.199999809265137, 7.0, 20.0, 12.59999942779541, 3.3999998569488525, 2.200000047683716, 12.199999809265137, 9.399999618530273, 23.0, 10.199999809265137, 7.799999713897705, 2.799999952316284, 11.0, 5.0, 23.19999885559082, 12.199999809265137, 9.59999942779541, 6.0, 16.799999237060547, 11.59999942779541, 33.39999771118164, 19.399999618530273, 11.199999809265137, 5.799999713897705] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [14.59999942779541, 9.0, 8.199999809265137, 5.799999713897705, 0.0, 0.19999998807907104, 8.399999618530273, 5.599999904632568, 12.199999809265137, 7.0, 12.199999809265137, 9.399999618530273, 11.0, 5.0, 16.799999237060547, 11.59999942779541] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [31.399999618530273, 16.600000381469727, 13.399999618530273, 9.399999618530273, 1.399999976158142, 0.3999999761581421, 22.399999618530273, 13.799999237060547, 20.0, 12.59999942779541, 23.0, 10.199999809265137, 23.19999885559082, 12.199999809265137, 33.39999771118164, 19.399999618530273] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [19.799999237060547, 17.799999237060547, 1.1999999284744263, 0.0, 1.5999999046325684, 1.399999976158142, 7.199999809265137, 2.3999998569488525, 3.3999998569488525, 2.200000047683716, 7.799999713897705, 2.799999952316284, 9.59999942779541, 6.0, 11.199999809265137, 5.799999713897705] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 10.42+/- 1.158 (max: 33.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 6.262+/- 1.468 (max: 19.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 8.562+/- 1.158 (max: 16.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 16.42+/- 2.308 (max: 33.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 10.45+/- 0.6433 (max: 20.41) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 7.941+/- 0.8221 (max: 13.41) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.878+/- 0.94 (max: 13.81) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.53+/- 1.122 (max: 20.41) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.09021+/- 0.01991 (max: 0.57) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.02437+/- 0.01565 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.04+/- 0.01114 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.2062+/- 0.04473 (max: 0.57) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.4 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 2.8 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [5.0, 2.3999998569488525, 11.800000190734863, 9.59999942779541, 4.400000095367432, 0.19999998807907104, 3.799999952316284, 4.599999904632568, 7.799999713897705, 41.599998474121094, 4.199999809265137, 0.19999998807907104, 4.0, 32.0, 9.199999809265137, 57.79999923706055, 3.3999998569488525, 8.399999618530273, 5.599999904632568, 13.59999942779541, 8.0, 5.400000095367432, 4.400000095367432, 0.5999999642372131, 2.200000047683716, 1.399999976158142, 8.0, 13.59999942779541, 9.199999809265137, 1.399999976158142, 3.0, 1.5999999046325684, 7.199999809265137, 24.399999618530273, 3.5999999046325684, 0.19999998807907104, 2.799999952316284, 3.1999998092651367, 7.399999618530273, 15.399999618530273, 4.400000095367432, 2.200000047683716, 3.1999998092651367, 3.3999998569488525, 11.0, 18.19999885559082, 5.199999809265137, 0.5999999642372131] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [5.0, 2.3999998569488525, 3.799999952316284, 4.599999904632568, 4.0, 32.0, 5.599999904632568, 13.59999942779541, 2.200000047683716, 1.399999976158142, 3.0, 1.5999999046325684, 2.799999952316284, 3.1999998092651367, 3.1999998092651367, 3.3999998569488525] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [11.800000190734863, 9.59999942779541, 7.799999713897705, 41.599998474121094, 9.199999809265137, 57.79999923706055, 8.0, 5.400000095367432, 8.0, 13.59999942779541, 7.199999809265137, 24.399999618530273, 7.399999618530273, 15.399999618530273, 11.0, 18.19999885559082] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [4.400000095367432, 0.19999998807907104, 4.199999809265137, 0.19999998807907104, 3.3999998569488525, 8.399999618530273, 4.400000095367432, 0.5999999642372131, 9.199999809265137, 1.399999976158142, 3.5999999046325684, 0.19999998807907104, 4.400000095367432, 2.200000047683716, 5.199999809265137, 0.5999999642372131] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.35+/- 1.562 (max: 57.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.287+/- 0.6988 (max: 9.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 5.737+/- 1.887 (max: 32.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 16.02+/- 3.578 (max: 57.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 11.03+/- 0.9006 (max: 30.19) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 6.996+/- 0.9011 (max: 13.39) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 9.519+/- 1.362 (max: 27.42) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 16.57+/- 1.294 (max: 30.19) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.08312+/- 0.02371 (max: 0.79) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.01312+/- 0.004539 (max: 0.06) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04437+/- 0.02831 (max: 0.45) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1919+/- 0.05697 (max: 0.79) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 1.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 5.4 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 5.103 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.95 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.03 | +------------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.199999809265137, 6.399999618530273, 50.0, 56.19999694824219, 66.79999542236328, 68.4000015258789, 9.800000190734863, 8.800000190734863, 53.79999923706055, 54.39999771118164, 70.0, 66.5999984741211, 12.399999618530273, 13.0, 62.19999694824219, 63.0, 56.599998474121094, 59.39999771118164, 11.800000190734863, 14.799999237060547, 54.39999771118164, 64.5999984741211, 39.39999771118164, 49.0, 5.599999904632568, 8.199999809265137, 51.19999694824219, 47.0, 61.599998474121094, 62.39999771118164, 11.800000190734863, 11.199999809265137, 63.599998474121094, 71.4000015258789, 75.4000015258789, 77.0, 10.199999809265137, 11.0, 57.39999771118164, 65.5999984741211, 67.19999694824219, 69.0, 15.799999237060547, 14.59999942779541, 62.19999694824219, 74.79999542236328, 74.5999984741211, 72.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.199999809265137, 6.399999618530273, 9.800000190734863, 8.800000190734863, 12.399999618530273, 13.0, 11.800000190734863, 14.799999237060547, 5.599999904632568, 8.199999809265137, 11.800000190734863, 11.199999809265137, 10.199999809265137, 11.0, 15.799999237060547, 14.59999942779541] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [50.0, 56.19999694824219, 53.79999923706055, 54.39999771118164, 62.19999694824219, 63.0, 54.39999771118164, 64.5999984741211, 51.19999694824219, 47.0, 63.599998474121094, 71.4000015258789, 57.39999771118164, 65.5999984741211, 62.19999694824219, 74.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [66.79999542236328, 68.4000015258789, 70.0, 66.5999984741211, 56.599998474121094, 59.39999771118164, 39.39999771118164, 49.0, 61.599998474121094, 62.39999771118164, 75.4000015258789, 77.0, 67.19999694824219, 69.0, 74.5999984741211, 72.79999542236328] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 45.03+/- 3.695 (max: 77.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 64.76+/- 2.499 (max: 77.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 10.85+/- 0.7368 (max: 15.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 59.49+/- 1.934 (max: 74.8) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 26.04+/- 1.36 (max: 36.25) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 32.06+/- 0.492 (max: 34.81) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.22+/- 0.4322 (max: 15.19) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 32.83+/- 0.7195 (max: 36.25) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.5617+/- 0.04909 (max: 0.93) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.8125+/- 0.02357 (max: 0.93) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.09687+/- 0.01341 (max: 0.19) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.7756+/- 0.01641 (max: 0.92) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 5.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 39.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 5.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 47.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.54 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.89 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.54 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.18 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.56 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.66 | +-------------------------------------------------------------------------------------------------- + + + + + + + +Evaluating ACCEL_CNN-LSTM_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [2.0, 3.799999952316284, 6.399999618530273, 7.199999809265137, 5.199999809265137, 5.400000095367432, 8.0, 8.399999618530273, 12.59999942779541, 16.19999885559082, 7.199999809265137, 10.199999809265137, 9.199999809265137, 9.59999942779541, 10.199999809265137, 13.199999809265137, 11.59999942779541, 11.0, 7.599999904632568, 7.399999618530273, 8.0, 8.199999809265137, 8.199999809265137, 8.199999809265137, 9.0, 10.199999809265137, 3.0, 3.0, 7.599999904632568, 8.800000190734863, 6.0, 7.799999713897705, 12.199999809265137, 15.199999809265137, 19.19999885559082, 17.399999618530273, 0.7999999523162842, 1.1999999284744263, 3.799999952316284, 6.0, 10.59999942779541, 16.0, 7.0, 10.59999942779541, 4.199999809265137, 7.0, 7.399999618530273, 14.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [2.0, 3.799999952316284, 8.0, 8.399999618530273, 9.199999809265137, 9.59999942779541, 7.599999904632568, 7.399999618530273, 9.0, 10.199999809265137, 6.0, 7.799999713897705, 0.7999999523162842, 1.1999999284744263, 7.0, 10.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [6.399999618530273, 7.199999809265137, 12.59999942779541, 16.19999885559082, 10.199999809265137, 13.199999809265137, 8.0, 8.199999809265137, 3.0, 3.0, 12.199999809265137, 15.199999809265137, 3.799999952316284, 6.0, 4.199999809265137, 7.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [5.199999809265137, 5.400000095367432, 7.199999809265137, 10.199999809265137, 11.59999942779541, 11.0, 8.199999809265137, 8.199999809265137, 7.599999904632568, 8.800000190734863, 19.19999885559082, 17.399999618530273, 10.59999942779541, 16.0, 7.399999618530273, 14.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 8.604+/- 0.5957 (max: 19.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 10.5+/- 1.048 (max: 19.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 6.787+/- 0.7919 (max: 10.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 8.525+/- 1.072 (max: 16.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 10.35+/- 0.2982 (max: 15.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.78+/- 0.3111 (max: 13.27) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 9.788+/- 0.6564 (max: 12.16) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 10.49+/- 0.5226 (max: 15.0) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.03708+/- 0.006489 (max: 0.19) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.04187+/- 0.01327 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0275+/- 0.005123 (max: 0.06) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.04187+/- 0.01358 (max: 0.19) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 5.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 3.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.773 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 7.681 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [2.5999999046325684, 0.0, 3.5999999046325684, 0.19999998807907104, 1.5999999046325684, 0.7999999523162842, 2.200000047683716, 0.19999998807907104, 5.799999713897705, 0.3999999761581421, 2.799999952316284, 1.0, 3.799999952316284, 0.7999999523162842, 3.1999998092651367, 1.5999999046325684, 2.200000047683716, 2.3999998569488525, 3.799999952316284, 1.0, 3.1999998092651367, 0.3999999761581421, 2.0, 0.19999998807907104, 3.3999998569488525, 0.3999999761581421, 4.400000095367432, 1.399999976158142, 5.599999904632568, 0.3999999761581421, 2.200000047683716, 1.0, 2.5999999046325684, 0.19999998807907104, 1.5999999046325684, 0.7999999523162842, 2.0, 0.0, 2.0, 0.3999999761581421, 2.200000047683716, 1.0, 4.400000095367432, 0.5999999642372131, 4.400000095367432, 1.0, 6.199999809265137, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [2.5999999046325684, 0.0, 2.200000047683716, 0.19999998807907104, 3.799999952316284, 0.7999999523162842, 3.799999952316284, 1.0, 3.3999998569488525, 0.3999999761581421, 2.200000047683716, 1.0, 2.0, 0.0, 4.400000095367432, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [3.5999999046325684, 0.19999998807907104, 5.799999713897705, 0.3999999761581421, 3.1999998092651367, 1.5999999046325684, 3.1999998092651367, 0.3999999761581421, 4.400000095367432, 1.399999976158142, 2.5999999046325684, 0.19999998807907104, 2.0, 0.3999999761581421, 4.400000095367432, 1.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.5999999046325684, 0.7999999523162842, 2.799999952316284, 1.0, 2.200000047683716, 2.3999998569488525, 2.0, 0.19999998807907104, 5.599999904632568, 0.3999999761581421, 1.5999999046325684, 0.7999999523162842, 2.200000047683716, 1.0, 6.199999809265137, 0.19999998807907104] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.962+/- 0.2372 (max: 6.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.937+/- 0.4362 (max: 6.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.775+/- 0.371 (max: 4.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.175+/- 0.4423 (max: 5.8) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 5.226+/- 0.3505 (max: 9.506) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 5.273+/- 0.5376 (max: 9.25) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.879+/- 0.6822 (max: 8.34) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 5.526+/- 0.6217 (max: 9.506) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0004167+/- 0.0002915 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.000625+/- 0.000625 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-LSTM_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [10.59999942779541, 10.199999809265137, 17.399999618530273, 13.199999809265137, 13.799999237060547, 12.199999809265137, 5.400000095367432, 5.0, 11.399999618530273, 9.800000190734863, 0.3999999761581421, 0.0, 0.19999998807907104, 0.0, 4.199999809265137, 0.7999999523162842, 2.200000047683716, 2.200000047683716, 6.0, 4.799999713897705, 19.799999237060547, 11.399999618530273, 9.399999618530273, 5.199999809265137, 12.0, 7.199999809265137, 12.399999618530273, 9.59999942779541, 3.3999998569488525, 1.7999999523162842, 8.399999618530273, 8.399999618530273, 16.600000381469727, 9.199999809265137, 6.399999618530273, 5.400000095367432, 6.0, 4.799999713897705, 11.0, 9.0, 7.799999713897705, 7.199999809265137, 14.59999942779541, 9.59999942779541, 25.399999618530273, 13.799999237060547, 9.59999942779541, 5.599999904632568] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [10.59999942779541, 10.199999809265137, 5.400000095367432, 5.0, 0.19999998807907104, 0.0, 6.0, 4.799999713897705, 12.0, 7.199999809265137, 8.399999618530273, 8.399999618530273, 6.0, 4.799999713897705, 14.59999942779541, 9.59999942779541] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [17.399999618530273, 13.199999809265137, 11.399999618530273, 9.800000190734863, 4.199999809265137, 0.7999999523162842, 19.799999237060547, 11.399999618530273, 12.399999618530273, 9.59999942779541, 16.600000381469727, 9.199999809265137, 11.0, 9.0, 25.399999618530273, 13.799999237060547] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [13.799999237060547, 12.199999809265137, 0.3999999761581421, 0.0, 2.200000047683716, 2.200000047683716, 9.399999618530273, 5.199999809265137, 3.3999998569488525, 1.7999999523162842, 6.399999618530273, 5.400000095367432, 7.799999713897705, 7.199999809265137, 9.59999942779541, 5.599999904632568] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 8.35+/- 0.7755 (max: 25.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 5.787+/- 1.02 (max: 13.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 7.075+/- 0.9791 (max: 14.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 12.19+/- 1.458 (max: 25.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 10.42+/- 0.629 (max: 23.13) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 8.145+/- 0.8287 (max: 13.17) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.386+/- 0.9118 (max: 13.22) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.73+/- 1.027 (max: 23.13) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.05792+/- 0.01221 (max: 0.41) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0125+/- 0.007042 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02937+/- 0.00915 (max: 0.12) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1319+/- 0.02666 (max: 0.41) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.8 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 3.919 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [3.799999952316284, 2.200000047683716, 3.3999998569488525, 7.199999809265137, 3.799999952316284, 0.3999999761581421, 1.7999999523162842, 4.400000095367432, 6.199999809265137, 34.79999923706055, 3.1999998092651367, 0.7999999523162842, 5.599999904632568, 31.599998474121094, 4.599999904632568, 51.599998474121094, 5.799999713897705, 8.199999809265137, 3.5999999046325684, 14.0, 4.199999809265137, 5.199999809265137, 5.0, 1.5999999046325684, 3.0, 1.1999999284744263, 5.599999904632568, 16.399999618530273, 6.399999618530273, 1.7999999523162842, 3.5999999046325684, 2.3999998569488525, 4.199999809265137, 23.600000381469727, 5.799999713897705, 0.5999999642372131, 2.5999999046325684, 1.1999999284744263, 4.799999713897705, 14.199999809265137, 6.399999618530273, 1.399999976158142, 2.799999952316284, 3.5999999046325684, 4.799999713897705, 18.799999237060547, 6.199999809265137, 0.19999998807907104] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [3.799999952316284, 2.200000047683716, 1.7999999523162842, 4.400000095367432, 5.599999904632568, 31.599998474121094, 3.5999999046325684, 14.0, 3.0, 1.1999999284744263, 3.5999999046325684, 2.3999998569488525, 2.5999999046325684, 1.1999999284744263, 2.799999952316284, 3.5999999046325684] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [3.3999998569488525, 7.199999809265137, 6.199999809265137, 34.79999923706055, 4.599999904632568, 51.599998474121094, 4.199999809265137, 5.199999809265137, 5.599999904632568, 16.399999618530273, 4.199999809265137, 23.600000381469727, 4.799999713897705, 14.199999809265137, 4.799999713897705, 18.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [3.799999952316284, 0.3999999761581421, 3.1999998092651367, 0.7999999523162842, 5.799999713897705, 8.199999809265137, 5.0, 1.5999999046325684, 6.399999618530273, 1.7999999523162842, 5.799999713897705, 0.5999999642372131, 6.399999618530273, 1.399999976158142, 6.199999809265137, 0.19999998807907104] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 7.387+/- 1.42 (max: 51.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.6+/- 0.6651 (max: 8.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 5.462+/- 1.895 (max: 31.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 13.1+/- 3.398 (max: 51.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 10.14+/- 0.8291 (max: 29.15) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.411+/- 0.8157 (max: 10.91) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 9.115+/- 1.341 (max: 27.01) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 13.88+/- 1.58 (max: 29.15) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.07312+/- 0.02311 (max: 0.76) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.01312+/- 0.00395 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04625+/- 0.03011 (max: 0.48) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.16+/- 0.05752 (max: 0.76) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 1.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 3.4 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 4.75 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 7.513 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-LSTM_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.0, 6.399999618530273, 63.39999771118164, 61.39999771118164, 70.4000015258789, 74.5999984741211, 10.800000190734863, 9.0, 60.19999694824219, 62.19999694824219, 76.79999542236328, 74.79999542236328, 13.199999809265137, 10.199999809265137, 66.5999984741211, 67.79999542236328, 61.79999923706055, 66.19999694824219, 11.800000190734863, 12.199999809265137, 73.4000015258789, 70.5999984741211, 55.0, 55.79999923706055, 5.799999713897705, 6.399999618530273, 53.599998474121094, 56.39999771118164, 64.4000015258789, 67.79999542236328, 12.799999237060547, 10.59999942779541, 70.19999694824219, 73.4000015258789, 81.4000015258789, 82.5999984741211, 11.399999618530273, 13.0, 62.19999694824219, 70.4000015258789, 72.0, 71.4000015258789, 14.799999237060547, 12.799999237060547, 66.5999984741211, 68.19999694824219, 74.5999984741211, 75.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.0, 6.399999618530273, 10.800000190734863, 9.0, 13.199999809265137, 10.199999809265137, 11.800000190734863, 12.199999809265137, 5.799999713897705, 6.399999618530273, 12.799999237060547, 10.59999942779541, 11.399999618530273, 13.0, 14.799999237060547, 12.799999237060547] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [63.39999771118164, 61.39999771118164, 60.19999694824219, 62.19999694824219, 66.5999984741211, 67.79999542236328, 73.4000015258789, 70.5999984741211, 53.599998474121094, 56.39999771118164, 70.19999694824219, 73.4000015258789, 62.19999694824219, 70.4000015258789, 66.5999984741211, 68.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [70.4000015258789, 74.5999984741211, 76.79999542236328, 74.79999542236328, 61.79999923706055, 66.19999694824219, 55.0, 55.79999923706055, 64.4000015258789, 67.79999542236328, 81.4000015258789, 82.5999984741211, 72.0, 71.4000015258789, 74.5999984741211, 75.19999694824219] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 48.76+/- 4.038 (max: 82.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 70.3+/- 2.017 (max: 82.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 10.57+/- 0.6841 (max: 14.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 65.41+/- 1.454 (max: 73.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 25.92+/- 1.398 (max: 37.72) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 31.87+/- 0.6411 (max: 35.73) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 12.76+/- 0.4283 (max: 15.97) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 33.13+/- 0.649 (max: 37.72) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.5917+/- 0.05304 (max: 0.95) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.8687+/- 0.01793 (max: 0.95) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.08375+/- 0.01068 (max: 0.15) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.8225+/- 0.01185 (max: 0.9) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 5.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 55.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 5.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 53.6 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 9.506 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 26.96 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 9.506 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 29.13 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.7 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.71 | +-------------------------------------------------------------------------------------------------- + + + + + + + + +Evaluating ACCEL_CNN-LSTM_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [6.0, 7.599999904632568, 17.399999618530273, 15.399999618530273, 24.399999618530273, 22.799999237060547, 15.0, 16.799999237060547, 32.0, 33.39999771118164, 20.600000381469727, 19.600000381469727, 12.799999237060547, 15.199999809265137, 24.799999237060547, 24.399999618530273, 23.19999885559082, 22.19999885559082, 17.0, 13.399999618530273, 16.0, 17.0, 23.0, 22.399999618530273, 14.0, 11.59999942779541, 22.399999618530273, 20.600000381469727, 24.19999885559082, 22.799999237060547, 15.399999618530273, 17.799999237060547, 28.599998474121094, 28.19999885559082, 37.39999771118164, 35.79999923706055, 7.799999713897705, 4.599999904632568, 21.799999237060547, 19.19999885559082, 29.399999618530273, 28.399999618530273, 12.399999618530273, 11.199999809265137, 15.0, 14.199999809265137, 15.799999237060547, 16.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [6.0, 7.599999904632568, 15.0, 16.799999237060547, 12.799999237060547, 15.199999809265137, 17.0, 13.399999618530273, 14.0, 11.59999942779541, 15.399999618530273, 17.799999237060547, 7.799999713897705, 4.599999904632568, 12.399999618530273, 11.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [17.399999618530273, 15.399999618530273, 32.0, 33.39999771118164, 24.799999237060547, 24.399999618530273, 16.0, 17.0, 22.399999618530273, 20.600000381469727, 28.599998474121094, 28.19999885559082, 21.799999237060547, 19.19999885559082, 15.0, 14.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [24.399999618530273, 22.799999237060547, 20.600000381469727, 19.600000381469727, 23.19999885559082, 22.19999885559082, 23.0, 22.399999618530273, 24.19999885559082, 22.799999237060547, 37.39999771118164, 35.79999923706055, 29.399999618530273, 28.399999618530273, 15.799999237060547, 16.399999618530273] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 19.53+/- 1.075 (max: 37.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 24.27+/- 1.49 (max: 37.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 12.41+/- 1.011 (max: 17.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 21.9+/- 1.542 (max: 33.4) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.34+/- 0.2959 (max: 18.35) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.98+/- 0.4412 (max: 14.23) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.52+/- 0.4541 (max: 15.25) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.51+/- 0.4522 (max: 18.35) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.2285+/- 0.02794 (max: 0.8) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2969+/- 0.05599 (max: 0.8) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.1+/- 0.01732 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2887+/- 0.04742 (max: 0.66) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 4.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 15.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 4.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 14.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 8.146 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.146 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 8.417 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.26 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.09 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.07 | +----------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [4.199999809265137, 0.0, 4.400000095367432, 0.3999999761581421, 4.199999809265137, 0.3999999761581421, 2.3999998569488525, 0.0, 3.0, 0.0, 3.0, 0.0, 4.599999904632568, 0.19999998807907104, 3.5999999046325684, 0.7999999523162842, 2.3999998569488525, 1.399999976158142, 4.0, 0.3999999761581421, 4.199999809265137, 0.0, 2.3999998569488525, 0.0, 3.5999999046325684, 0.19999998807907104, 5.199999809265137, 0.19999998807907104, 4.400000095367432, 0.3999999761581421, 2.799999952316284, 0.19999998807907104, 2.200000047683716, 0.0, 1.7999999523162842, 0.3999999761581421, 1.5999999046325684, 0.19999998807907104, 2.0, 0.19999998807907104, 1.7999999523162842, 0.19999998807907104, 2.200000047683716, 0.0, 5.599999904632568, 0.19999998807907104, 5.199999809265137, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [4.199999809265137, 0.0, 2.3999998569488525, 0.0, 4.599999904632568, 0.19999998807907104, 4.0, 0.3999999761581421, 3.5999999046325684, 0.19999998807907104, 2.799999952316284, 0.19999998807907104, 1.5999999046325684, 0.19999998807907104, 2.200000047683716, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [4.400000095367432, 0.3999999761581421, 3.0, 0.0, 3.5999999046325684, 0.7999999523162842, 4.199999809265137, 0.0, 5.199999809265137, 0.19999998807907104, 2.200000047683716, 0.0, 2.0, 0.19999998807907104, 5.599999904632568, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [4.199999809265137, 0.3999999761581421, 3.0, 0.0, 2.3999998569488525, 1.399999976158142, 2.3999998569488525, 0.0, 4.400000095367432, 0.3999999761581421, 1.7999999523162842, 0.3999999761581421, 1.7999999523162842, 0.19999998807907104, 5.199999809265137, 0.0] +---------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.804+/- 0.2608 (max: 5.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.75+/- 0.4307 (max: 5.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.662+/- 0.4323 (max: 4.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.0+/- 0.5128 (max: 5.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.572+/- 0.4702 (max: 9.319) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 4.618+/- 0.7886 (max: 8.773) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.346+/- 0.8285 (max: 9.319) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.753+/- 0.8742 (max: 9.217) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.002083+/- 0.000663 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.001875+/- 0.001008 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.001875+/- 0.00136 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0025+/- 0.001118 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +---------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [4.0, 2.200000047683716, 4.599999904632568, 2.3999998569488525, 0.7999999523162842, 1.0, 1.5999999046325684, 1.1999999284744263, 2.0, 2.0, 0.3999999761581421, 0.0, 4.599999904632568, 2.200000047683716, 13.0, 5.599999904632568, 4.799999713897705, 3.0, 3.5999999046325684, 1.0, 12.199999809265137, 8.800000190734863, 3.0, 0.5999999642372131, 3.3999998569488525, 2.5999999046325684, 8.399999618530273, 5.0, 3.1999998092651367, 1.0, 4.400000095367432, 1.399999976158142, 7.399999618530273, 3.5999999046325684, 3.1999998092651367, 1.5999999046325684, 1.7999999523162842, 1.0, 4.599999904632568, 2.3999998569488525, 2.200000047683716, 0.3999999761581421, 3.799999952316284, 1.1999999284744263, 2.3999998569488525, 1.5999999046325684, 3.799999952316284, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [4.0, 2.200000047683716, 1.5999999046325684, 1.1999999284744263, 4.599999904632568, 2.200000047683716, 3.5999999046325684, 1.0, 3.3999998569488525, 2.5999999046325684, 4.400000095367432, 1.399999976158142, 1.7999999523162842, 1.0, 3.799999952316284, 1.1999999284744263] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [4.599999904632568, 2.3999998569488525, 2.0, 2.0, 13.0, 5.599999904632568, 12.199999809265137, 8.800000190734863, 8.399999618530273, 5.0, 7.399999618530273, 3.5999999046325684, 4.599999904632568, 2.3999998569488525, 2.3999998569488525, 1.5999999046325684] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [0.7999999523162842, 1.0, 0.3999999761581421, 0.0, 4.799999713897705, 3.0, 3.0, 0.5999999642372131, 3.1999998092651367, 1.0, 3.1999998092651367, 1.5999999046325684, 2.200000047683716, 0.3999999761581421, 3.799999952316284, 0.19999998807907104] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 3.233+/- 0.4047 (max: 13.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.825+/- 0.3732 (max: 4.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 2.5+/- 0.3199 (max: 4.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 5.375+/- 0.9079 (max: 13.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 6.684+/- 0.3522 (max: 11.96) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.113+/- 0.6292 (max: 8.542) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 6.42+/- 0.3651 (max: 8.485) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.52+/- 0.4931 (max: 11.96) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.005+/- 0.001518 (max: 0.06) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0125+/- 0.003819 (max: 0.06) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 1.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 4.359 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 5.426 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-LSTM_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.19999998807907104, 7.199999809265137, 0.19999998807907104, 16.799999237060547, 0.3999999761581421, 10.399999618530273, 0.3999999761581421, 10.399999618530273, 0.3999999761581421, 42.39999771118164, 0.5999999642372131, 11.0, 0.19999998807907104, 41.39999771118164, 0.19999998807907104, 59.39999771118164, 0.19999998807907104, 19.600000381469727, 0.19999998807907104, 20.600000381469727, 0.19999998807907104, 15.799999237060547, 0.3999999761581421, 10.199999809265137, 0.19999998807907104, 6.199999809265137, 0.19999998807907104, 25.19999885559082, 0.3999999761581421, 12.799999237060547, 0.19999998807907104, 6.799999713897705, 0.19999998807907104, 28.799999237060547, 0.3999999761581421, 10.0, 0.19999998807907104, 6.199999809265137, 0.19999998807907104, 28.19999885559082, 0.3999999761581421, 13.399999618530273, 0.3999999761581421, 9.800000190734863, 0.19999998807907104, 28.19999885559082, 0.19999998807907104, 10.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.19999998807907104, 7.199999809265137, 0.3999999761581421, 10.399999618530273, 0.19999998807907104, 41.39999771118164, 0.19999998807907104, 20.600000381469727, 0.19999998807907104, 6.199999809265137, 0.19999998807907104, 6.799999713897705, 0.19999998807907104, 6.199999809265137, 0.3999999761581421, 9.800000190734863] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.19999998807907104, 16.799999237060547, 0.3999999761581421, 42.39999771118164, 0.19999998807907104, 59.39999771118164, 0.19999998807907104, 15.799999237060547, 0.19999998807907104, 25.19999885559082, 0.19999998807907104, 28.799999237060547, 0.19999998807907104, 28.19999885559082, 0.19999998807907104, 28.19999885559082] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.3999999761581421, 10.399999618530273, 0.5999999642372131, 11.0, 0.19999998807907104, 19.600000381469727, 0.3999999761581421, 10.199999809265137, 0.3999999761581421, 12.799999237060547, 0.3999999761581421, 10.0, 0.3999999761581421, 13.399999618530273, 0.19999998807907104, 10.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 9.533+/- 1.922 (max: 59.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 6.275+/- 1.623 (max: 19.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.912+/- 2.703 (max: 41.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 15.41+/- 4.615 (max: 59.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 9.404+/- 1.199 (max: 31.81) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.925+/- 1.398 (max: 16.73) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.098+/- 1.813 (max: 26.57) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.19+/- 2.74 (max: 31.81) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1154+/- 0.02953 (max: 0.84) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.04937+/- 0.01822 (max: 0.26) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.07187+/- 0.04141 (max: 0.64) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.225+/- 0.07024 (max: 0.84) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-LSTM_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [10.59999942779541, 8.199999809265137, 66.79999542236328, 63.79999923706055, 105.39999389648438, 105.19999694824219, 17.600000381469727, 16.399999618530273, 74.4000015258789, 69.4000015258789, 104.79999542236328, 102.0, 19.19999885559082, 17.799999237060547, 94.4000015258789, 93.19999694824219, 104.5999984741211, 103.5999984741211, 20.799999237060547, 19.600000381469727, 80.4000015258789, 80.79999542236328, 91.5999984741211, 88.5999984741211, 9.399999618530273, 10.800000190734863, 68.5999984741211, 71.0, 95.0, 96.5999984741211, 18.399999618530273, 17.799999237060547, 86.4000015258789, 84.0, 100.0, 102.79999542236328, 20.0, 18.19999885559082, 73.19999694824219, 73.0, 98.5999984741211, 103.5999984741211, 17.399999618530273, 16.600000381469727, 84.0, 82.5999984741211, 104.0, 104.39999389648438] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [10.59999942779541, 8.199999809265137, 17.600000381469727, 16.399999618530273, 19.19999885559082, 17.799999237060547, 20.799999237060547, 19.600000381469727, 9.399999618530273, 10.800000190734863, 18.399999618530273, 17.799999237060547, 20.0, 18.19999885559082, 17.399999618530273, 16.600000381469727] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [66.79999542236328, 63.79999923706055, 74.4000015258789, 69.4000015258789, 94.4000015258789, 93.19999694824219, 80.4000015258789, 80.79999542236328, 68.5999984741211, 71.0, 86.4000015258789, 84.0, 73.19999694824219, 73.0, 84.0, 82.5999984741211] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [105.39999389648438, 105.19999694824219, 104.79999542236328, 102.0, 104.5999984741211, 103.5999984741211, 91.5999984741211, 88.5999984741211, 95.0, 96.5999984741211, 100.0, 102.79999542236328, 98.5999984741211, 103.5999984741211, 104.0, 104.39999389648438] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 64.91+/- 5.288 (max: 105.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 100.7+/- 1.299 (max: 105.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 16.17+/- 1.009 (max: 20.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 77.87+/- 2.306 (max: 94.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 20.01+/- 0.7235 (max: 28.98) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 21.78+/- 0.568 (max: 25.34) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.89+/- 0.4176 (max: 15.98) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 24.36+/- 0.6765 (max: 28.98) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.7027+/- 0.05534 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9912+/- 0.002016 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1712+/- 0.02091 (max: 0.3) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9456+/- 0.00532 (max: 0.98) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 8.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 88.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 8.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 63.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.62 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 16.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.62 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 18.7 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.98 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.91 | +------------------------------------------------------------------------------------------------- diff --git a/src/run_results_txt/eval_xpid_all_cnn_s5_out.txt b/src/run_results_txt/eval_xpid_all_cnn_s5_out.txt new file mode 100644 index 0000000..16f5ab5 --- /dev/null +++ b/src/run_results_txt/eval_xpid_all_cnn_s5_out.txt @@ -0,0 +1,2290 @@ +Evaluating DR_CNN-S5_SEED1 against population in Overcooked-CoordRing6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-CoordRing6_9, v [3.0, 0.0, 0.3999999761581421, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5999999046325684, 0.0, 0.0, 0.0, 0.0, 0.0, 10.800000190734863, 0.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.200000047683716, 0.0, 12.799999237060547, 0.0, 0.0, 0.0, 0.3999999761581421, 0.0, 3.799999952316284, 0.0, 0.0, 0.0, 1.5999999046325684, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [3.0, 0.0, 0.0, 0.0, 1.5999999046325684, 0.0, 10.800000190734863, 0.0, 0.0, 0.0, 2.200000047683716, 0.0, 0.3999999761581421, 0.0, 1.5999999046325684, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [0.3999999761581421, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0, 0.0, 12.799999237060547, 0.0, 3.799999952316284, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 0.95+/- 0.3949 (max: 12.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 1.225+/- 0.6824 (max: 10.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 1.625+/- 0.9513 (max: 12.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 1.459+/- 0.4567 (max: 11.5) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 2.372+/- 0.8713 (max: 10.36) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 2.006+/- 0.992 (max: 11.5) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.001458+/- 0.001073 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.003125+/- 0.003125 (max: 0.05) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.004167+/- 0.004167 (max: 0.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0125+/- 0.0125 (max: 0.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.04146+/- 0.04146 (max: 1.99) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.1244+/- 0.1244 (max: 1.99) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-S5_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 1.0, 0.0, 6.399999618530273, 0.0, 0.0, 0.0, 4.599999904632568, 0.0, 44.20000076293945, 0.0, 0.19999998807907104, 0.0, 32.0, 0.0, 60.39999771118164, 0.0, 9.199999809265137, 0.0, 16.399999618530273, 0.0, 7.0, 0.0, 0.19999998807907104, 0.0, 0.5999999642372131, 0.0, 14.59999942779541, 0.0, 0.7999999523162842, 0.0, 1.7999999523162842, 0.0, 26.599998474121094, 0.0, 0.0, 0.0, 0.7999999523162842, 0.0, 15.59999942779541, 0.0, 1.1999999284744263, 0.0, 2.3999998569488525, 0.0, 18.799999237060547, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 1.0, 0.0, 4.599999904632568, 0.0, 32.0, 0.0, 16.399999618530273, 0.0, 0.5999999642372131, 0.0, 1.7999999523162842, 0.0, 0.7999999523162842, 0.0, 2.3999998569488525] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 6.399999618530273, 0.0, 44.20000076293945, 0.0, 60.39999771118164, 0.0, 7.0, 0.0, 14.59999942779541, 0.0, 26.599998474121094, 0.0, 15.59999942779541, 0.0, 18.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 9.199999809265137, 0.0, 0.19999998807907104, 0.0, 0.7999999523162842, 0.0, 0.0, 0.0, 1.1999999284744263, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 5.517+/- 1.782 (max: 60.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.725+/- 0.5715 (max: 9.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 3.725+/- 2.144 (max: 32.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 12.1+/- 4.503 (max: 60.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 5.472+/- 1.273 (max: 35.1) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.414+/- 0.6869 (max: 9.968) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 4.819+/- 1.914 (max: 27.71) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.18+/- 2.912 (max: 35.1) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.06771+/- 0.02541 (max: 0.76) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04187+/- 0.03017 (max: 0.44) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1612+/- 0.06501 (max: 0.76) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-S5_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [7.799999713897705, 0.0, 95.5999984741211, 0.0, 65.79999542236328, 0.0, 15.0, 0.0, 69.4000015258789, 0.0, 19.0, 0.0, 14.399999618530273, 0.0, 98.79999542236328, 0.0, 6.599999904632568, 0.0, 17.0, 0.0, 89.0, 0.0, 21.19999885559082, 0.0, 8.800000190734863, 0.0, 97.0, 0.0, 23.399999618530273, 0.0, 16.799999237060547, 0.0, 59.79999923706055, 0.0, 12.799999237060547, 0.0, 14.199999809265137, 0.0, 105.19999694824219, 0.0, 13.0, 0.0, 21.0, 0.0, 90.79999542236328, 0.0, 10.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [7.799999713897705, 0.0, 15.0, 0.0, 14.399999618530273, 0.0, 17.0, 0.0, 8.800000190734863, 0.0, 16.799999237060547, 0.0, 14.199999809265137, 0.0, 21.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [95.5999984741211, 0.0, 69.4000015258789, 0.0, 98.79999542236328, 0.0, 89.0, 0.0, 97.0, 0.0, 59.79999923706055, 0.0, 105.19999694824219, 0.0, 90.79999542236328, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [65.79999542236328, 0.0, 19.0, 0.0, 6.599999904632568, 0.0, 21.19999885559082, 0.0, 23.399999618530273, 0.0, 12.799999237060547, 0.0, 13.0, 0.0, 10.0, 0.0] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 20.67+/- 4.772 (max: 105.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 10.74+/- 4.242 (max: 65.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.187+/- 1.998 (max: 21.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 44.1+/- 11.69 (max: 105.2) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.87+/- 2.129 (max: 47.28) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 9.75+/- 2.695 (max: 31.6) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 6.978+/- 1.844 (max: 17.52) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 18.89+/- 5.173 (max: 47.28) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.2175+/- 0.04867 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.1394+/- 0.05959 (max: 0.94) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.07687+/- 0.02623 (max: 0.33) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.4362+/- 0.1145 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-CoordRing6_9, v [5.599999904632568, 6.0, 20.399999618530273, 19.799999237060547, 16.799999237060547, 17.19999885559082, 14.799999237060547, 15.799999237060547, 28.799999237060547, 33.79999923706055, 17.0, 18.399999618530273, 12.799999237060547, 13.0, 26.399999618530273, 23.600000381469727, 17.0, 18.399999618530273, 14.799999237060547, 13.59999942779541, 19.799999237060547, 19.19999885559082, 13.59999942779541, 16.19999885559082, 15.199999809265137, 14.0, 21.0, 20.19999885559082, 17.600000381469727, 20.799999237060547, 16.0, 14.399999618530273, 28.0, 27.399999618530273, 31.399999618530273, 32.39999771118164, 3.799999952316284, 2.799999952316284, 20.399999618530273, 20.399999618530273, 22.799999237060547, 26.399999618530273, 11.399999618530273, 12.799999237060547, 16.399999618530273, 15.59999942779541, 15.59999942779541, 17.600000381469727] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [5.599999904632568, 6.0, 14.799999237060547, 15.799999237060547, 12.799999237060547, 13.0, 14.799999237060547, 13.59999942779541, 15.199999809265137, 14.0, 16.0, 14.399999618530273, 3.799999952316284, 2.799999952316284, 11.399999618530273, 12.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [20.399999618530273, 19.799999237060547, 28.799999237060547, 33.79999923706055, 26.399999618530273, 23.600000381469727, 19.799999237060547, 19.19999885559082, 21.0, 20.19999885559082, 28.0, 27.399999618530273, 20.399999618530273, 20.399999618530273, 16.399999618530273, 15.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [16.799999237060547, 17.19999885559082, 17.0, 18.399999618530273, 17.0, 18.399999618530273, 13.59999942779541, 16.19999885559082, 17.600000381469727, 20.799999237060547, 31.399999618530273, 32.39999771118164, 22.799999237060547, 26.399999618530273, 15.59999942779541, 17.600000381469727] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 18.07+/- 0.9793 (max: 33.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 19.95+/- 1.385 (max: 32.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.67+/- 1.115 (max: 16.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 22.57+/- 1.24 (max: 33.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.57+/- 0.3182 (max: 17.95) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 12.31+/- 0.6169 (max: 17.95) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.03+/- 0.5932 (max: 15.05) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.38+/- 0.3891 (max: 17.66) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.184+/- 0.02202 (max: 0.6) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.1869+/- 0.04031 (max: 0.56) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.08437+/- 0.01557 (max: 0.2) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2806+/- 0.03748 (max: 0.6) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 13.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 15.6 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.666 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.16 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.06 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.07 | +----------------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [6.0, 2.799999952316284, 10.399999618530273, 6.199999809265137, 7.599999904632568, 1.399999976158142, 3.799999952316284, 2.3999998569488525, 5.0, 3.5999999046325684, 4.199999809265137, 0.0, 2.3999998569488525, 1.1999999284744263, 11.0, 4.799999713897705, 6.199999809265137, 1.7999999523162842, 4.199999809265137, 3.0, 13.199999809265137, 10.800000190734863, 10.0, 2.3999998569488525, 4.599999904632568, 1.7999999523162842, 10.800000190734863, 6.799999713897705, 8.800000190734863, 1.7999999523162842, 8.399999618530273, 1.5999999046325684, 10.0, 4.199999809265137, 12.0, 6.399999618530273, 3.799999952316284, 2.200000047683716, 10.800000190734863, 8.0, 7.399999618530273, 1.1999999284744263, 3.1999998092651367, 2.200000047683716, 5.799999713897705, 6.799999713897705, 7.599999904632568, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [6.0, 2.799999952316284, 3.799999952316284, 2.3999998569488525, 2.3999998569488525, 1.1999999284744263, 4.199999809265137, 3.0, 4.599999904632568, 1.7999999523162842, 8.399999618530273, 1.5999999046325684, 3.799999952316284, 2.200000047683716, 3.1999998092651367, 2.200000047683716] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [10.399999618530273, 6.199999809265137, 5.0, 3.5999999046325684, 11.0, 4.799999713897705, 13.199999809265137, 10.800000190734863, 10.800000190734863, 6.799999713897705, 10.0, 4.199999809265137, 10.800000190734863, 8.0, 5.799999713897705, 6.799999713897705] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [7.599999904632568, 1.399999976158142, 4.199999809265137, 0.0, 6.199999809265137, 1.7999999523162842, 10.0, 2.3999998569488525, 8.800000190734863, 1.7999999523162842, 12.0, 6.399999618530273, 7.399999618530273, 1.1999999284744263, 7.599999904632568, 0.19999998807907104] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 5.433+/- 0.506 (max: 13.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 4.937+/- 0.9451 (max: 12.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.35+/- 0.4573 (max: 8.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.012+/- 0.7453 (max: 13.2) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 8.302+/- 0.3819 (max: 13.03) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 7.307+/- 0.7994 (max: 10.58) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.157+/- 0.3573 (max: 10.27) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.44+/- 0.3814 (max: 13.03) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.01042+/- 0.002962 (max: 0.1) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0025+/- 0.001443 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.02812+/- 0.006965 (max: 0.1) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 3.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 4.75 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.146 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 0.7999999523162842, 0.0, 8.199999809265137, 0.0, 0.0, 0.0, 3.799999952316284, 0.0, 42.599998474121094, 0.0, 0.19999998807907104, 0.0, 30.19999885559082, 0.0, 50.39999771118164, 0.0, 8.59999942779541, 0.0, 13.799999237060547, 0.0, 5.599999904632568, 0.0, 0.5999999642372131, 0.0, 1.1999999284744263, 0.0, 16.799999237060547, 0.0, 0.3999999761581421, 0.0, 0.7999999523162842, 0.0, 25.399999618530273, 0.0, 0.19999998807907104, 0.0, 1.5999999046325684, 0.0, 13.199999809265137, 0.0, 2.5999999046325684, 0.0, 2.3999998569488525, 0.0, 19.799999237060547, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 0.7999999523162842, 0.0, 3.799999952316284, 0.0, 30.19999885559082, 0.0, 13.799999237060547, 0.0, 1.1999999284744263, 0.0, 0.7999999523162842, 0.0, 1.5999999046325684, 0.0, 2.3999998569488525] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 8.199999809265137, 0.0, 42.599998474121094, 0.0, 50.39999771118164, 0.0, 5.599999904632568, 0.0, 16.799999237060547, 0.0, 25.399999618530273, 0.0, 13.199999809265137, 0.0, 19.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 8.59999942779541, 0.0, 0.5999999642372131, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.0, 2.5999999046325684, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 5.192+/- 1.612 (max: 50.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.7875+/- 0.5454 (max: 8.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 3.412+/- 1.981 (max: 30.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 11.37+/- 4.019 (max: 50.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 5.41+/- 1.211 (max: 31.68) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.676+/- 0.7238 (max: 9.902) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 5.024+/- 1.983 (max: 27.2) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 9.528+/- 2.696 (max: 31.68) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.06604+/- 0.02441 (max: 0.76) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.03875+/- 0.0274 (max: 0.4) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1594+/- 0.0627 (max: 0.76) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-S5_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 + +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [9.199999809265137, 8.59999942779541, 67.79999542236328, 66.79999542236328, 103.19999694824219, 105.0, 15.399999618530273, 14.799999237060547, 76.5999984741211, 75.19999694824219, 102.5999984741211, 107.0, 20.0, 18.19999885559082, 95.4000015258789, 95.4000015258789, 100.5999984741211, 101.4000015258789, 17.799999237060547, 17.0, 89.4000015258789, 87.19999694824219, 90.19999694824219, 90.19999694824219, 9.59999942779541, 7.0, 76.4000015258789, 75.19999694824219, 100.0, 100.79999542236328, 18.19999885559082, 17.399999618530273, 85.79999542236328, 82.5999984741211, 102.0, 104.79999542236328, 17.600000381469727, 16.600000381469727, 84.19999694824219, 83.0, 100.4000015258789, 103.0, 21.19999885559082, 22.399999618530273, 86.0, 86.79999542236328, 100.79999542236328, 97.5999984741211] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [9.199999809265137, 8.59999942779541, 15.399999618530273, 14.799999237060547, 20.0, 18.19999885559082, 17.799999237060547, 17.0, 9.59999942779541, 7.0, 18.19999885559082, 17.399999618530273, 17.600000381469727, 16.600000381469727, 21.19999885559082, 22.399999618530273] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [67.79999542236328, 66.79999542236328, 76.5999984741211, 75.19999694824219, 95.4000015258789, 95.4000015258789, 89.4000015258789, 87.19999694824219, 76.4000015258789, 75.19999694824219, 85.79999542236328, 82.5999984741211, 84.19999694824219, 83.0, 86.0, 86.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [103.19999694824219, 105.0, 102.5999984741211, 107.0, 100.5999984741211, 101.4000015258789, 90.19999694824219, 90.19999694824219, 100.0, 100.79999542236328, 102.0, 104.79999542236328, 100.4000015258789, 103.0, 100.79999542236328, 97.5999984741211] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 66.13+/- 5.39 (max: 107.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 100.6+/- 1.16 (max: 107.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.69+/- 1.165 (max: 22.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 82.11+/- 2.127 (max: 95.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 18.98+/- 0.7133 (max: 32.68) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 21.69+/- 1.192 (max: 32.68) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.78+/- 0.3937 (max: 15.49) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 21.47+/- 0.7079 (max: 26.1) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.7112+/- 0.05656 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.99+/- 0.005083 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1675+/- 0.02118 (max: 0.32) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9762+/- 0.004366 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 90.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 66.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 17.46 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 16.84 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.94 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.95 | +------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [3.0, 0.0, 0.3999999761581421, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5999999046325684, 0.0, 0.0, 0.0, 0.0, 0.0, 10.800000190734863, 0.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 2.200000047683716, 0.0, 12.799999237060547, 0.0, 0.0, 0.0, 0.3999999761581421, 0.0, 3.799999952316284, 0.0, 0.0, 0.0, 1.5999999046325684, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [3.0, 0.0, 0.0, 0.0, 1.5999999046325684, 0.0, 10.800000190734863, 0.0, 0.0, 0.0, 2.200000047683716, 0.0, 0.3999999761581421, 0.0, 1.5999999046325684, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [0.3999999761581421, 0.0, 0.0, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0, 0.0, 12.799999237060547, 0.0, 3.799999952316284, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 0.95+/- 0.3949 (max: 12.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 1.225+/- 0.6824 (max: 10.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 1.625+/- 0.9513 (max: 12.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 1.459+/- 0.4567 (max: 11.5) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 2.372+/- 0.8713 (max: 10.36) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 2.006+/- 0.992 (max: 11.5) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.001458+/- 0.001073 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.003125+/- 0.003125 (max: 0.05) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------- +Evaluating DR_CNN-S5_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.004167+/- 0.004167 (max: 0.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0125+/- 0.0125 (max: 0.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.04146+/- 0.04146 (max: 1.99) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.1244+/- 0.1244 (max: 1.99) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-S5_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 1.0, 0.0, 6.399999618530273, 0.0, 0.0, 0.0, 4.599999904632568, 0.0, 44.20000076293945, 0.0, 0.19999998807907104, 0.0, 32.0, 0.0, 60.39999771118164, 0.0, 9.199999809265137, 0.0, 16.399999618530273, 0.0, 7.0, 0.0, 0.19999998807907104, 0.0, 0.5999999642372131, 0.0, 14.59999942779541, 0.0, 0.7999999523162842, 0.0, 1.7999999523162842, 0.0, 26.599998474121094, 0.0, 0.0, 0.0, 0.7999999523162842, 0.0, 15.59999942779541, 0.0, 1.1999999284744263, 0.0, 2.3999998569488525, 0.0, 18.799999237060547, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 1.0, 0.0, 4.599999904632568, 0.0, 32.0, 0.0, 16.399999618530273, 0.0, 0.5999999642372131, 0.0, 1.7999999523162842, 0.0, 0.7999999523162842, 0.0, 2.3999998569488525] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 6.399999618530273, 0.0, 44.20000076293945, 0.0, 60.39999771118164, 0.0, 7.0, 0.0, 14.59999942779541, 0.0, 26.599998474121094, 0.0, 15.59999942779541, 0.0, 18.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 9.199999809265137, 0.0, 0.19999998807907104, 0.0, 0.7999999523162842, 0.0, 0.0, 0.0, 1.1999999284744263, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 5.517+/- 1.782 (max: 60.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.725+/- 0.5715 (max: 9.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 3.725+/- 2.144 (max: 32.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 12.1+/- 4.503 (max: 60.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 5.472+/- 1.273 (max: 35.1) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.414+/- 0.6869 (max: 9.968) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 4.819+/- 1.914 (max: 27.71) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.18+/- 2.912 (max: 35.1) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.06771+/- 0.02541 (max: 0.76) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04187+/- 0.03017 (max: 0.44) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1612+/- 0.06501 (max: 0.76) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_CNN-S5_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [7.799999713897705, 0.0, 95.5999984741211, 0.0, 65.79999542236328, 0.0, 15.0, 0.0, 69.4000015258789, 0.0, 19.0, 0.0, 14.399999618530273, 0.0, 98.79999542236328, 0.0, 6.599999904632568, 0.0, 17.0, 0.0, 89.0, 0.0, 21.19999885559082, 0.0, 8.800000190734863, 0.0, 97.0, 0.0, 23.399999618530273, 0.0, 16.799999237060547, 0.0, 59.79999923706055, 0.0, 12.799999237060547, 0.0, 14.199999809265137, 0.0, 105.19999694824219, 0.0, 13.0, 0.0, 21.0, 0.0, 90.79999542236328, 0.0, 10.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [7.799999713897705, 0.0, 15.0, 0.0, 14.399999618530273, 0.0, 17.0, 0.0, 8.800000190734863, 0.0, 16.799999237060547, 0.0, 14.199999809265137, 0.0, 21.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [95.5999984741211, 0.0, 69.4000015258789, 0.0, 98.79999542236328, 0.0, 89.0, 0.0, 97.0, 0.0, 59.79999923706055, 0.0, 105.19999694824219, 0.0, 90.79999542236328, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [65.79999542236328, 0.0, 19.0, 0.0, 6.599999904632568, 0.0, 21.19999885559082, 0.0, 23.399999618530273, 0.0, 12.799999237060547, 0.0, 13.0, 0.0, 10.0, 0.0] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 20.67+/- 4.772 (max: 105.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 10.74+/- 4.242 (max: 65.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.187+/- 1.998 (max: 21.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 44.1+/- 11.69 (max: 105.2) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.87+/- 2.129 (max: 47.28) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 9.75+/- 2.695 (max: 31.6) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 6.978+/- 1.844 (max: 17.52) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 18.89+/- 5.173 (max: 47.28) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.2175+/- 0.04867 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.1394+/- 0.05959 (max: 0.94) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.07687+/- 0.02623 (max: 0.33) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.4362+/- 0.1145 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-S5_SEED1 against population in Overcooked-CoordRing6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [6.599999904632568, 5.0, 16.19999885559082, 19.0, 23.0, 22.399999618530273, 16.0, 14.0, 31.599998474121094, 34.0, 21.0, 21.600000381469727, 12.199999809265137, 13.799999237060547, 23.399999618530273, 23.0, 23.19999885559082, 23.0, 14.799999237060547, 10.399999618530273, 17.19999885559082, 16.600000381469727, 20.19999885559082, 21.799999237060547, 17.0, 16.399999618530273, 22.399999618530273, 21.0, 23.19999885559082, 24.600000381469727, 15.59999942779541, 17.0, 26.399999618530273, 30.0, 36.0, 36.0, 6.599999904632568, 6.599999904632568, 21.799999237060547, 20.799999237060547, 31.599998474121094, 31.799999237060547, 15.0, 15.399999618530273, 21.19999885559082, 19.19999885559082, 22.399999618530273, 23.600000381469727] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [6.599999904632568, 5.0, 16.0, 14.0, 12.199999809265137, 13.799999237060547, 14.799999237060547, 10.399999618530273, 17.0, 16.399999618530273, 15.59999942779541, 17.0, 6.599999904632568, 6.599999904632568, 15.0, 15.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [16.19999885559082, 19.0, 31.599998474121094, 34.0, 23.399999618530273, 23.0, 17.19999885559082, 16.600000381469727, 22.399999618530273, 21.0, 26.399999618530273, 30.0, 21.799999237060547, 20.799999237060547, 21.19999885559082, 19.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [23.0, 22.399999618530273, 21.0, 21.600000381469727, 23.19999885559082, 23.0, 20.19999885559082, 21.799999237060547, 23.19999885559082, 24.600000381469727, 36.0, 36.0, 31.599998474121094, 31.799999237060547, 22.399999618530273, 23.600000381469727] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 20.24+/- 1.062 (max: 36.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 25.34+/- 1.324 (max: 36.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 12.65+/- 1.053 (max: 17.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 22.74+/- 1.322 (max: 34.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.17+/- 0.242 (max: 16.22) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.37+/- 0.4545 (max: 14.45) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.17+/- 0.4191 (max: 14.86) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.95+/- 0.2916 (max: 16.22) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.234+/- 0.02608 (max: 0.72) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.3194+/- 0.04962 (max: 0.72) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.09187+/- 0.01412 (max: 0.18) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2906+/- 0.04108 (max: 0.69) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 5.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 20.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 5.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 16.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 8.818 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.818 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 9.11 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.6 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.15 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.12 | +------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [4.400000095367432, 0.0, 6.0, 0.19999998807907104, 3.799999952316284, 0.7999999523162842, 3.0, 0.0, 3.799999952316284, 0.3999999761581421, 3.5999999046325684, 0.0, 4.199999809265137, 0.19999998807907104, 5.599999904632568, 0.5999999642372131, 2.3999998569488525, 0.7999999523162842, 4.400000095367432, 0.19999998807907104, 4.199999809265137, 0.0, 2.5999999046325684, 0.0, 3.1999998092651367, 0.0, 5.199999809265137, 0.0, 5.199999809265137, 0.0, 3.1999998092651367, 0.19999998807907104, 3.3999998569488525, 0.0, 1.5999999046325684, 0.0, 2.5999999046325684, 0.0, 2.200000047683716, 0.0, 1.1999999284744263, 0.0, 5.400000095367432, 0.0, 6.199999809265137, 0.0, 6.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [4.400000095367432, 0.0, 3.0, 0.0, 4.199999809265137, 0.19999998807907104, 4.400000095367432, 0.19999998807907104, 3.1999998092651367, 0.0, 3.1999998092651367, 0.19999998807907104, 2.5999999046325684, 0.0, 5.400000095367432, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [6.0, 0.19999998807907104, 3.799999952316284, 0.3999999761581421, 5.599999904632568, 0.5999999642372131, 4.199999809265137, 0.0, 5.199999809265137, 0.0, 3.3999998569488525, 0.0, 2.200000047683716, 0.0, 6.199999809265137, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [3.799999952316284, 0.7999999523162842, 3.5999999046325684, 0.0, 2.3999998569488525, 0.7999999523162842, 2.5999999046325684, 0.0, 5.199999809265137, 0.0, 1.5999999046325684, 0.0, 1.1999999284744263, 0.0, 6.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 2.017+/- 0.31 (max: 6.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.75+/- 0.4968 (max: 6.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.937+/- 0.5075 (max: 5.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.362+/- 0.6211 (max: 6.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.349+/- 0.5361 (max: 10.08) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 4.044+/- 0.8899 (max: 9.165) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.286+/- 0.9403 (max: 8.879) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.716+/- 1.005 (max: 10.08) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0008333+/- 0.0005012 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.001875+/- 0.00136 (max: 0.02) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [5.0, 2.799999952316284, 7.399999618530273, 7.599999904632568, 1.5999999046325684, 2.5999999046325684, 3.3999998569488525, 3.0, 5.199999809265137, 4.0, 0.5999999642372131, 0.0, 4.799999713897705, 2.3999998569488525, 13.799999237060547, 5.400000095367432, 9.0, 4.199999809265137, 5.599999904632568, 2.200000047683716, 15.799999237060547, 12.199999809265137, 4.799999713897705, 1.1999999284744263, 4.400000095367432, 3.5999999046325684, 12.199999809265137, 11.399999618530273, 4.599999904632568, 0.5999999642372131, 7.0, 2.5999999046325684, 9.199999809265137, 5.400000095367432, 5.400000095367432, 2.200000047683716, 4.0, 2.3999998569488525, 7.199999809265137, 6.199999809265137, 2.799999952316284, 0.3999999761581421, 5.199999809265137, 1.7999999523162842, 7.199999809265137, 8.800000190734863, 4.599999904632568, 0.3999999761581421] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [5.0, 2.799999952316284, 3.3999998569488525, 3.0, 4.799999713897705, 2.3999998569488525, 5.599999904632568, 2.200000047683716, 4.400000095367432, 3.5999999046325684, 7.0, 2.5999999046325684, 4.0, 2.3999998569488525, 5.199999809265137, 1.7999999523162842] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [7.399999618530273, 7.599999904632568, 5.199999809265137, 4.0, 13.799999237060547, 5.400000095367432, 15.799999237060547, 12.199999809265137, 12.199999809265137, 11.399999618530273, 9.199999809265137, 5.400000095367432, 7.199999809265137, 6.199999809265137, 7.199999809265137, 8.800000190734863] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [1.5999999046325684, 2.5999999046325684, 0.5999999642372131, 0.0, 9.0, 4.199999809265137, 4.799999713897705, 1.1999999284744263, 4.599999904632568, 0.5999999642372131, 5.400000095367432, 2.200000047683716, 2.799999952316284, 0.3999999761581421, 4.599999904632568, 0.3999999761581421] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 5.087+/- 0.5241 (max: 15.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 2.812+/- 0.616 (max: 9.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.762+/- 0.3652 (max: 7.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.687+/- 0.8618 (max: 15.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 8.045+/- 0.3831 (max: 12.11) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.964+/- 0.7107 (max: 9.95) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.664+/- 0.3013 (max: 9.95) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.51+/- 0.2844 (max: 12.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.01083+/- 0.002627 (max: 0.06) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.03062+/- 0.004956 (max: 0.06) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 4.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.724 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.485 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-S5_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.3999999761581421, 8.199999809265137, 0.0, 21.19999885559082, 0.19999998807907104, 10.0, 0.0, 10.199999809265137, 0.0, 45.79999923706055, 0.19999998807907104, 6.799999713897705, 0.0, 44.79999923706055, 0.5999999642372131, 63.39999771118164, 0.3999999761581421, 19.399999618530273, 0.0, 25.599998474121094, 0.0, 16.600000381469727, 0.0, 12.199999809265137, 0.0, 8.59999942779541, 0.0, 30.399999618530273, 0.0, 13.199999809265137, 0.0, 9.59999942779541, 0.0, 35.79999923706055, 0.3999999761581421, 7.599999904632568, 0.0, 7.799999713897705, 0.0, 30.599998474121094, 0.19999998807907104, 14.399999618530273, 0.19999998807907104, 9.59999942779541, 0.0, 27.0, 0.19999998807907104, 6.799999713897705] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.3999999761581421, 8.199999809265137, 0.0, 10.199999809265137, 0.0, 44.79999923706055, 0.0, 25.599998474121094, 0.0, 8.59999942779541, 0.0, 9.59999942779541, 0.0, 7.799999713897705, 0.19999998807907104, 9.59999942779541] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 21.19999885559082, 0.0, 45.79999923706055, 0.5999999642372131, 63.39999771118164, 0.0, 16.600000381469727, 0.0, 30.399999618530273, 0.0, 35.79999923706055, 0.0, 30.599998474121094, 0.0, 27.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.19999998807907104, 10.0, 0.19999998807907104, 6.799999713897705, 0.3999999761581421, 19.399999618530273, 0.0, 12.199999809265137, 0.0, 13.199999809265137, 0.3999999761581421, 7.599999904632568, 0.19999998807907104, 14.399999618530273, 0.19999998807907104, 6.799999713897705] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 10.17+/- 2.111 (max: 63.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 5.75+/- 1.618 (max: 19.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 7.812+/- 3.011 (max: 44.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 16.96+/- 5.046 (max: 63.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.741+/- 1.279 (max: 28.15) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.385+/- 1.53 (max: 15.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.799+/- 2.103 (max: 27.87) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 11.04+/- 2.843 (max: 28.15) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.13+/- 0.03234 (max: 0.89) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.04687+/- 0.01736 (max: 0.23) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.08937+/- 0.0458 (max: 0.69) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2537+/- 0.07644 (max: 0.89) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.199999809265137, 6.799999713897705, 70.0, 70.0, 102.79999542236328, 103.5999984741211, 13.399999618530273, 13.399999618530273, 79.0, 70.0, 105.19999694824219, 108.79999542236328, 16.799999237060547, 15.199999809265137, 97.4000015258789, 95.79999542236328, 112.39999389648438, 110.0, 16.0, 15.0, 89.5999984741211, 80.19999694824219, 82.19999694824219, 84.79999542236328, 6.399999618530273, 7.0, 66.0, 68.19999694824219, 99.4000015258789, 107.39999389648438, 15.799999237060547, 16.799999237060547, 94.0, 87.5999984741211, 109.39999389648438, 104.79999542236328, 15.59999942779541, 15.59999942779541, 84.79999542236328, 76.4000015258789, 105.5999984741211, 104.0, 14.799999237060547, 15.199999809265137, 92.79999542236328, 86.0, 111.5999984741211, 108.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.199999809265137, 6.799999713897705, 13.399999618530273, 13.399999618530273, 16.799999237060547, 15.199999809265137, 16.0, 15.0, 6.399999618530273, 7.0, 15.799999237060547, 16.799999237060547, 15.59999942779541, 15.59999942779541, 14.799999237060547, 15.199999809265137] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [70.0, 70.0, 79.0, 70.0, 97.4000015258789, 95.79999542236328, 89.5999984741211, 80.19999694824219, 66.0, 68.19999694824219, 94.0, 87.5999984741211, 84.79999542236328, 76.4000015258789, 92.79999542236328, 86.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [102.79999542236328, 103.5999984741211, 105.19999694824219, 108.79999542236328, 112.39999389648438, 110.0, 82.19999694824219, 84.79999542236328, 99.4000015258789, 107.39999389648438, 109.39999389648438, 104.79999542236328, 105.5999984741211, 104.0, 111.5999984741211, 108.79999542236328] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 66.26+/- 5.743 (max: 112.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 103.8+/- 2.162 (max: 112.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 13.25+/- 0.9498 (max: 16.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 81.74+/- 2.679 (max: 97.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 20.45+/- 0.9273 (max: 32.84) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 21.97+/- 0.8477 (max: 29.82) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.05+/- 0.4577 (max: 15.31) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.32+/- 1.025 (max: 32.84) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6823+/- 0.05864 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9875+/- 0.005809 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1181+/- 0.01651 (max: 0.18) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9412+/- 0.01103 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 6.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 82.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 6.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 66.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 9.887 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 18.59 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 9.887 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 20.3 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.93 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.85 | +------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-S5_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [6.799999713897705, 5.400000095367432, 17.0, 18.19999885559082, 22.0, 21.19999885559082, 14.59999942779541, 13.0, 32.0, 30.599998474121094, 21.399999618530273, 21.600000381469727, 11.199999809265137, 12.59999942779541, 24.399999618530273, 24.0, 22.799999237060547, 22.600000381469727, 16.600000381469727, 11.800000190734863, 14.799999237060547, 15.199999809265137, 18.600000381469727, 19.399999618530273, 14.399999618530273, 13.59999942779541, 19.19999885559082, 17.0, 22.0, 23.399999618530273, 13.59999942779541, 14.199999809265137, 22.600000381469727, 24.799999237060547, 34.39999771118164, 34.79999923706055, 3.799999952316284, 5.0, 15.799999237060547, 17.600000381469727, 24.799999237060547, 25.0, 13.0, 12.0, 10.199999809265137, 12.199999809265137, 9.800000190734863, 10.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [6.799999713897705, 5.400000095367432, 14.59999942779541, 13.0, 11.199999809265137, 12.59999942779541, 16.600000381469727, 11.800000190734863, 14.399999618530273, 13.59999942779541, 13.59999942779541, 14.199999809265137, 3.799999952316284, 5.0, 13.0, 12.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [17.0, 18.19999885559082, 32.0, 30.599998474121094, 24.399999618530273, 24.0, 14.799999237060547, 15.199999809265137, 19.19999885559082, 17.0, 22.600000381469727, 24.799999237060547, 15.799999237060547, 17.600000381469727, 10.199999809265137, 12.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [22.0, 21.19999885559082, 21.399999618530273, 21.600000381469727, 22.799999237060547, 22.600000381469727, 18.600000381469727, 19.399999618530273, 22.0, 23.399999618530273, 34.39999771118164, 34.79999923706055, 24.799999237060547, 25.0, 9.800000190734863, 10.399999618530273] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 17.74+/- 1.045 (max: 34.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 22.14+/- 1.633 (max: 34.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.35+/- 0.972 (max: 16.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 19.72+/- 1.545 (max: 32.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.87+/- 0.265 (max: 14.73) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.88+/- 0.4691 (max: 14.16) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.98+/- 0.4862 (max: 14.17) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.76+/- 0.2903 (max: 14.73) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.1806+/- 0.02446 (max: 0.73) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.24+/- 0.05151 (max: 0.73) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.07625+/- 0.01217 (max: 0.13) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2256+/- 0.04152 (max: 0.58) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 3.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 9.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 3.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 10.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 7.618 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 7.618 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 7.846 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.28 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.04 | +------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [3.1999998092651367, 0.19999998807907104, 2.799999952316284, 0.7999999523162842, 2.799999952316284, 2.0, 2.799999952316284, 0.7999999523162842, 1.7999999523162842, 0.7999999523162842, 1.1999999284744263, 0.7999999523162842, 1.0, 0.5999999642372131, 2.3999998569488525, 0.7999999523162842, 2.3999998569488525, 1.399999976158142, 1.5999999046325684, 1.1999999284744263, 2.200000047683716, 1.399999976158142, 0.3999999761581421, 1.399999976158142, 2.3999998569488525, 1.0, 4.199999809265137, 1.7999999523162842, 1.5999999046325684, 0.5999999642372131, 1.0, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 1.0, 0.5999999642372131, 1.1999999284744263, 0.3999999761581421, 1.0, 0.19999998807907104, 0.3999999761581421, 0.7999999523162842, 2.5999999046325684, 0.19999998807907104, 2.200000047683716, 0.3999999761581421, 1.7999999523162842, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [3.1999998092651367, 0.19999998807907104, 2.799999952316284, 0.7999999523162842, 1.0, 0.5999999642372131, 1.5999999046325684, 1.1999999284744263, 2.3999998569488525, 1.0, 1.0, 0.19999998807907104, 1.1999999284744263, 0.3999999761581421, 2.5999999046325684, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [2.799999952316284, 0.7999999523162842, 1.7999999523162842, 0.7999999523162842, 2.3999998569488525, 0.7999999523162842, 2.200000047683716, 1.399999976158142, 4.199999809265137, 1.7999999523162842, 0.5999999642372131, 0.19999998807907104, 1.0, 0.19999998807907104, 2.200000047683716, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [2.799999952316284, 2.0, 1.1999999284744263, 0.7999999523162842, 2.3999998569488525, 1.399999976158142, 0.3999999761581421, 1.399999976158142, 1.5999999046325684, 0.5999999642372131, 1.0, 0.5999999642372131, 0.3999999761581421, 0.7999999523162842, 1.7999999523162842, 0.19999998807907104] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.321+/- 0.1358 (max: 4.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.212+/- 0.1893 (max: 2.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.275+/- 0.2442 (max: 3.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.475+/- 0.2744 (max: 4.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.608+/- 0.2413 (max: 8.623) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 4.51+/- 0.3564 (max: 6.94) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.475+/- 0.4438 (max: 7.332) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.84+/- 0.467 (max: 8.623) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0002083+/- 0.0002083 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.000625+/- 0.000625 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [2.200000047683716, 1.1999999284744263, 2.5999999046325684, 2.5999999046325684, 0.5999999642372131, 1.399999976158142, 1.0, 1.399999976158142, 1.5999999046325684, 2.799999952316284, 1.0, 0.0, 4.400000095367432, 1.399999976158142, 10.59999942779541, 4.799999713897705, 6.199999809265137, 3.0, 3.1999998092651367, 1.1999999284744263, 8.800000190734863, 7.799999713897705, 4.400000095367432, 1.7999999523162842, 2.5999999046325684, 2.799999952316284, 9.800000190734863, 5.0, 5.799999713897705, 2.0, 3.5999999046325684, 1.5999999046325684, 5.199999809265137, 3.799999952316284, 3.1999998092651367, 1.1999999284744263, 0.19999998807907104, 1.399999976158142, 1.0, 0.7999999523162842, 1.399999976158142, 1.1999999284744263, 3.799999952316284, 3.1999998092651367, 2.0, 1.7999999523162842, 1.7999999523162842, 0.3999999761581421] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [2.200000047683716, 1.1999999284744263, 1.0, 1.399999976158142, 4.400000095367432, 1.399999976158142, 3.1999998092651367, 1.1999999284744263, 2.5999999046325684, 2.799999952316284, 3.5999999046325684, 1.5999999046325684, 0.19999998807907104, 1.399999976158142, 3.799999952316284, 3.1999998092651367] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [2.5999999046325684, 2.5999999046325684, 1.5999999046325684, 2.799999952316284, 10.59999942779541, 4.799999713897705, 8.800000190734863, 7.799999713897705, 9.800000190734863, 5.0, 5.199999809265137, 3.799999952316284, 1.0, 0.7999999523162842, 2.0, 1.7999999523162842] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [0.5999999642372131, 1.399999976158142, 1.0, 0.0, 6.199999809265137, 3.0, 4.400000095367432, 1.7999999523162842, 5.799999713897705, 2.0, 3.1999998092651367, 1.1999999284744263, 1.399999976158142, 1.1999999284744263, 1.7999999523162842, 0.3999999761581421] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 2.95+/- 0.3522 (max: 10.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 2.212+/- 0.4617 (max: 6.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 2.2+/- 0.3 (max: 4.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 4.437+/- 0.8009 (max: 10.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 6.484+/- 0.35 (max: 12.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.596+/- 0.5967 (max: 9.25) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.967+/- 0.4242 (max: 8.34) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 7.889+/- 0.6428 (max: 12.8) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.004167+/- 0.001903 (max: 0.08) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.01125+/- 0.005313 (max: 0.08) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.8 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 3.919 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-S5_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [1.1999999284744263, 7.0, 5.799999713897705, 16.399999618530273, 1.7999999523162842, 7.799999713897705, 2.0, 10.59999942779541, 4.0, 47.79999923706055, 1.5999999046325684, 10.199999809265137, 3.3999998569488525, 40.599998474121094, 5.400000095367432, 66.79999542236328, 2.5999999046325684, 17.19999885559082, 4.0, 27.399999618530273, 4.199999809265137, 16.19999885559082, 1.5999999046325684, 13.59999942779541, 2.200000047683716, 6.599999904632568, 4.799999713897705, 28.19999885559082, 4.799999713897705, 15.199999809265137, 2.200000047683716, 8.199999809265137, 5.199999809265137, 33.20000076293945, 2.5999999046325684, 11.0, 1.0, 7.199999809265137, 5.400000095367432, 29.0, 1.0, 16.0, 1.399999976158142, 11.0, 5.599999904632568, 31.599998474121094, 2.0, 7.599999904632568] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [1.1999999284744263, 7.0, 2.0, 10.59999942779541, 3.3999998569488525, 40.599998474121094, 4.0, 27.399999618530273, 2.200000047683716, 6.599999904632568, 2.200000047683716, 8.199999809265137, 1.0, 7.199999809265137, 1.399999976158142, 11.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [5.799999713897705, 16.399999618530273, 4.0, 47.79999923706055, 5.400000095367432, 66.79999542236328, 4.199999809265137, 16.19999885559082, 4.799999713897705, 28.19999885559082, 5.199999809265137, 33.20000076293945, 5.400000095367432, 29.0, 5.599999904632568, 31.599998474121094] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.7999999523162842, 7.799999713897705, 1.5999999046325684, 10.199999809265137, 2.5999999046325684, 17.19999885559082, 1.5999999046325684, 13.59999942779541, 4.799999713897705, 15.199999809265137, 2.5999999046325684, 11.0, 1.0, 16.0, 2.0, 7.599999904632568] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 11.71+/- 1.985 (max: 66.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 7.287+/- 1.46 (max: 17.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 8.5+/- 2.688 (max: 40.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 19.35+/- 4.669 (max: 66.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 12.37+/- 0.9426 (max: 29.58) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 10.65+/- 1.203 (max: 17.67) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 10.46+/- 1.462 (max: 26.15) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 15.99+/- 1.858 (max: 29.58) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1371+/- 0.03183 (max: 0.91) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.065+/- 0.01895 (max: 0.2) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.08812+/- 0.04367 (max: 0.61) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2581+/- 0.07598 (max: 0.91) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 1.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 1.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 1.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 4.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 4.359 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 4.359 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 4.359 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 8.998 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.01 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [7.599999904632568, 9.800000190734863, 65.4000015258789, 61.39999771118164, 92.79999542236328, 89.5999984741211, 14.799999237060547, 14.799999237060547, 67.5999984741211, 64.5999984741211, 85.4000015258789, 85.79999542236328, 16.799999237060547, 16.19999885559082, 79.0, 90.0, 83.0, 82.5999984741211, 14.799999237060547, 17.399999618530273, 79.4000015258789, 79.0, 73.19999694824219, 66.4000015258789, 8.399999618530273, 9.0, 73.19999694824219, 64.79999542236328, 80.79999542236328, 78.0, 16.799999237060547, 14.59999942779541, 77.0, 74.5999984741211, 89.0, 82.0, 14.199999809265137, 13.59999942779541, 73.79999542236328, 70.79999542236328, 82.79999542236328, 80.4000015258789, 17.600000381469727, 16.19999885559082, 83.5999984741211, 79.19999694824219, 82.19999694824219, 86.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [7.599999904632568, 9.800000190734863, 14.799999237060547, 14.799999237060547, 16.799999237060547, 16.19999885559082, 14.799999237060547, 17.399999618530273, 8.399999618530273, 9.0, 16.799999237060547, 14.59999942779541, 14.199999809265137, 13.59999942779541, 17.600000381469727, 16.19999885559082] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [65.4000015258789, 61.39999771118164, 67.5999984741211, 64.5999984741211, 79.0, 90.0, 79.4000015258789, 79.0, 73.19999694824219, 64.79999542236328, 77.0, 74.5999984741211, 73.79999542236328, 70.79999542236328, 83.5999984741211, 79.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [92.79999542236328, 89.5999984741211, 85.4000015258789, 85.79999542236328, 83.0, 82.5999984741211, 73.19999694824219, 66.4000015258789, 80.79999542236328, 78.0, 89.0, 82.0, 82.79999542236328, 80.4000015258789, 82.19999694824219, 86.79999542236328] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 56.81+/- 4.538 (max: 92.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 82.55+/- 1.595 (max: 92.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 13.91+/- 0.8347 (max: 17.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 73.96+/- 1.959 (max: 90.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 21.84+/- 0.9106 (max: 32.97) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 25.4+/- 0.6193 (max: 29.1) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.69+/- 0.433 (max: 16.42) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.42+/- 0.745 (max: 32.97) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6762+/- 0.05639 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.965+/- 0.008165 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1337+/- 0.01741 (max: 0.23) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.93+/- 0.006831 (max: 0.99) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 66.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 61.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.65 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 19.9 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.65 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 20.95 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.87 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.88 | +------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-S5_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.400000095367432, 3.5999999046325684, 13.59999942779541, 12.59999942779541, 16.19999885559082, 11.800000190734863, 16.19999885559082, 14.0, 25.19999885559082, 24.799999237060547, 12.59999942779541, 10.199999809265137, 12.0, 11.800000190734863, 15.199999809265137, 14.0, 14.59999942779541, 11.0, 9.399999618530273, 7.599999904632568, 17.0, 16.19999885559082, 11.800000190734863, 10.0, 16.799999237060547, 13.59999942779541, 17.0, 17.600000381469727, 15.59999942779541, 13.199999809265137, 9.800000190734863, 10.199999809265137, 19.19999885559082, 18.600000381469727, 24.0, 25.19999885559082, 2.5999999046325684, 2.5999999046325684, 16.600000381469727, 14.799999237060547, 20.0, 21.799999237060547, 6.0, 8.0, 12.199999809265137, 7.199999809265137, 15.399999618530273, 16.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.400000095367432, 3.5999999046325684, 16.19999885559082, 14.0, 12.0, 11.800000190734863, 9.399999618530273, 7.599999904632568, 16.799999237060547, 13.59999942779541, 9.800000190734863, 10.199999809265137, 2.5999999046325684, 2.5999999046325684, 6.0, 8.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [13.59999942779541, 12.59999942779541, 25.19999885559082, 24.799999237060547, 15.199999809265137, 14.0, 17.0, 16.19999885559082, 17.0, 17.600000381469727, 19.19999885559082, 18.600000381469727, 16.600000381469727, 14.799999237060547, 12.199999809265137, 7.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [16.19999885559082, 11.800000190734863, 12.59999942779541, 10.199999809265137, 14.59999942779541, 11.0, 11.800000190734863, 10.0, 15.59999942779541, 13.199999809265137, 24.0, 25.19999885559082, 20.0, 21.799999237060547, 15.399999618530273, 16.399999618530273] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 13.75+/- 0.8015 (max: 25.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 15.61+/- 1.204 (max: 25.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 9.287+/- 1.156 (max: 16.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 16.36+/- 1.113 (max: 25.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.34+/- 0.2731 (max: 15.56) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.22+/- 0.3897 (max: 15.13) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 10.89+/- 0.5991 (max: 14.34) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.9+/- 0.3934 (max: 15.56) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.09083+/- 0.01369 (max: 0.35) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.09562+/- 0.028 (max: 0.35) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.05437+/- 0.01369 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.1225+/- 0.02514 (max: 0.32) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 10.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 7.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.726 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.113 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.726 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 9.6 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.5999999642372131, 0.0, 0.19999998807907104, 0.0, 1.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.19999998807907104, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.5999999642372131, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.19999998807907104, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.1+/- 0.02793 (max: 1.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.1+/- 0.06583 (max: 1.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.1+/- 0.04472 (max: 0.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.1+/- 0.03162 (max: 0.4) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.7515+/- 0.1719 (max: 4.359) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.5718+/- 0.3264 (max: 4.359) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.7614+/- 0.3034 (max: 3.412) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.9212+/- 0.274 (max: 2.8) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-S5_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [2.200000047683716, 0.5999999642372131, 3.5999999046325684, 3.0, 3.0, 2.5999999046325684, 0.7999999523162842, 2.0, 2.5999999046325684, 2.0, 0.3999999761581421, 0.0, 3.1999998092651367, 1.7999999523162842, 10.399999618530273, 4.400000095367432, 8.0, 2.200000047683716, 3.3999998569488525, 2.3999998569488525, 11.800000190734863, 8.0, 3.1999998092651367, 0.19999998807907104, 4.400000095367432, 1.1999999284744263, 13.799999237060547, 9.199999809265137, 7.399999618530273, 1.399999976158142, 2.5999999046325684, 1.1999999284744263, 6.0, 4.799999713897705, 4.599999904632568, 0.7999999523162842, 0.7999999523162842, 0.5999999642372131, 4.0, 4.599999904632568, 1.0, 0.0, 1.5999999046325684, 0.19999998807907104, 2.5999999046325684, 2.0, 2.5999999046325684, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [2.200000047683716, 0.5999999642372131, 0.7999999523162842, 2.0, 3.1999998092651367, 1.7999999523162842, 3.3999998569488525, 2.3999998569488525, 4.400000095367432, 1.1999999284744263, 2.5999999046325684, 1.1999999284744263, 0.7999999523162842, 0.5999999642372131, 1.5999999046325684, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [3.5999999046325684, 3.0, 2.5999999046325684, 2.0, 10.399999618530273, 4.400000095367432, 11.800000190734863, 8.0, 13.799999237060547, 9.199999809265137, 6.0, 4.799999713897705, 4.0, 4.599999904632568, 2.5999999046325684, 2.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [3.0, 2.5999999046325684, 0.3999999761581421, 0.0, 8.0, 2.200000047683716, 3.1999998092651367, 0.19999998807907104, 7.399999618530273, 1.399999976158142, 4.599999904632568, 0.7999999523162842, 1.0, 0.0, 2.5999999046325684, 0.0] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 3.317+/- 0.4574 (max: 13.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 2.337+/- 0.6255 (max: 8.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.812+/- 0.2935 (max: 4.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 5.8+/- 0.9293 (max: 13.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 6.457+/- 0.437 (max: 12.68) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.064+/- 0.8473 (max: 10.2) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.37+/- 0.4351 (max: 8.285) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.936+/- 0.5195 (max: 12.68) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.006458+/- 0.00257 (max: 0.08) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.01812+/- 0.006905 (max: 0.08) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 2.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 6.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_CNN-S5_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 6.599999904632568, 0.0, 15.59999942779541, 0.0, 5.400000095367432, 0.0, 7.599999904632568, 0.0, 44.599998474121094, 0.0, 2.5999999046325684, 0.0, 40.79999923706055, 0.0, 62.19999694824219, 0.0, 17.799999237060547, 0.0, 20.600000381469727, 0.0, 11.59999942779541, 0.0, 6.599999904632568, 0.0, 5.400000095367432, 0.0, 22.600000381469727, 0.0, 11.0, 0.0, 5.599999904632568, 0.0, 31.599998474121094, 0.0, 3.1999998092651367, 0.0, 6.599999904632568, 0.0, 18.600000381469727, 0.0, 10.199999809265137, 0.19999998807907104, 6.799999713897705, 0.0, 27.19999885559082, 0.0, 2.200000047683716] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 6.599999904632568, 0.0, 7.599999904632568, 0.0, 40.79999923706055, 0.0, 20.600000381469727, 0.0, 5.400000095367432, 0.0, 5.599999904632568, 0.0, 6.599999904632568, 0.19999998807907104, 6.799999713897705] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 15.59999942779541, 0.0, 44.599998474121094, 0.0, 62.19999694824219, 0.0, 11.59999942779541, 0.0, 22.600000381469727, 0.0, 31.599998474121094, 0.0, 18.600000381469727, 0.0, 27.19999885559082] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 5.400000095367432, 0.0, 2.5999999046325684, 0.0, 17.799999237060547, 0.0, 6.599999904632568, 0.0, 11.0, 0.0, 3.1999998092651367, 0.0, 10.199999809265137, 0.0, 2.200000047683716] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.192+/- 1.951 (max: 62.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.687+/- 1.321 (max: 17.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.262+/- 2.672 (max: 40.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 14.62+/- 4.747 (max: 62.2) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 7.613+/- 1.283 (max: 30.45) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 5.605+/- 1.624 (max: 18.52) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 6.955+/- 2.099 (max: 29.11) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.28+/- 2.768 (max: 30.45) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1017+/- 0.02945 (max: 0.84) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.03062+/- 0.01487 (max: 0.21) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.065+/- 0.04174 (max: 0.65) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2094+/- 0.07065 (max: 0.84) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_CNN-S5_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [7.799999713897705, 7.199999809265137, 61.79999923706055, 61.0, 69.0, 75.0, 16.19999885559082, 16.799999237060547, 70.5999984741211, 71.19999694824219, 79.5999984741211, 71.0, 17.600000381469727, 16.799999237060547, 83.4000015258789, 92.0, 71.4000015258789, 67.5999984741211, 16.799999237060547, 19.399999618530273, 77.5999984741211, 77.4000015258789, 60.599998474121094, 52.19999694824219, 7.0, 7.199999809265137, 66.79999542236328, 70.0, 69.0, 60.79999923706055, 16.600000381469727, 17.19999885559082, 75.19999694824219, 78.79999542236328, 67.79999542236328, 64.19999694824219, 16.19999885559082, 16.399999618530273, 74.0, 70.79999542236328, 72.79999542236328, 65.5999984741211, 16.0, 17.600000381469727, 79.0, 80.79999542236328, 73.5999984741211, 65.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [7.799999713897705, 7.199999809265137, 16.19999885559082, 16.799999237060547, 17.600000381469727, 16.799999237060547, 16.799999237060547, 19.399999618530273, 7.0, 7.199999809265137, 16.600000381469727, 17.19999885559082, 16.19999885559082, 16.399999618530273, 16.0, 17.600000381469727] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [61.79999923706055, 61.0, 70.5999984741211, 71.19999694824219, 83.4000015258789, 92.0, 77.5999984741211, 77.4000015258789, 66.79999542236328, 70.0, 75.19999694824219, 78.79999542236328, 74.0, 70.79999542236328, 79.0, 80.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [69.0, 75.0, 79.5999984741211, 71.0, 71.4000015258789, 67.5999984741211, 60.599998474121094, 52.19999694824219, 69.0, 60.79999923706055, 67.79999542236328, 64.19999694824219, 72.79999542236328, 65.5999984741211, 73.5999984741211, 65.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 52.26+/- 4.014 (max: 92.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 67.82+/- 1.64 (max: 79.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 14.55+/- 1.099 (max: 19.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 74.4+/- 1.985 (max: 92.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 20.95+/- 0.9244 (max: 30.37) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.35+/- 0.4303 (max: 30.37) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 12.93+/- 0.4385 (max: 15.68) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 22.56+/- 0.6814 (max: 27.47) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6635+/- 0.05573 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9056+/- 0.01599 (max: 0.97) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.13+/- 0.01839 (max: 0.23) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.955+/- 0.007246 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 52.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 61.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 24.72 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 18.78 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.74 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9 | +------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED1 against population in Overcooked-CoordRing6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [2.3999998569488525, 0.0, 0.5999999642372131, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7999999523162842, 0.0, 0.0, 0.0, 0.0, 0.0, 7.799999713897705, 0.0, 4.400000095367432, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0, 1.5999999046325684, 0.0, 10.59999942779541, 0.0, 0.0, 0.0, 1.1999999284744263, 0.0, 4.0, 0.0, 0.0, 0.0, 0.7999999523162842, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [2.3999998569488525, 0.0, 0.0, 0.0, 0.7999999523162842, 0.0, 7.799999713897705, 0.0, 0.19999998807907104, 0.0, 1.5999999046325684, 0.0, 1.1999999284744263, 0.0, 0.7999999523162842, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [0.5999999642372131, 0.0, 0.0, 0.0, 0.0, 0.0, 4.400000095367432, 0.0, 0.0, 0.0, 10.59999942779541, 0.0, 4.0, 0.0, 0.19999998807907104, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +--------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 0.7208+/- 0.2941 (max: 10.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.925+/- 0.4926 (max: 7.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 1.237+/- 0.7177 (max: 10.6) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 1.474+/- 0.4333 (max: 11.47) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 2.291+/- 0.7855 (max: 10.16) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 2.132+/- 0.9643 (max: 11.47) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.001458+/- 0.0009411 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.00375+/- 0.00272 (max: 0.04) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +--------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.004167+/- 0.004167 (max: 0.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0125+/- 0.0125 (max: 0.2) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.04146+/- 0.04146 (max: 1.99) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.1244+/- 0.1244 (max: 1.99) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-S5_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 1.0, 0.0, 8.399999618530273, 0.0, 0.0, 0.0, 4.799999713897705, 0.0, 43.0, 0.0, 0.7999999523162842, 0.0, 29.0, 0.0, 58.0, 0.0, 7.799999713897705, 0.0, 15.399999618530273, 0.0, 8.59999942779541, 0.0, 0.3999999761581421, 0.0, 0.7999999523162842, 0.0, 17.399999618530273, 0.0, 0.5999999642372131, 0.0, 1.5999999046325684, 0.0, 26.599998474121094, 0.0, 0.0, 0.0, 2.200000047683716, 0.0, 13.799999237060547, 0.0, 2.0, 0.0, 2.5999999046325684, 0.0, 16.799999237060547, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 1.0, 0.0, 4.799999713897705, 0.0, 29.0, 0.0, 15.399999618530273, 0.0, 0.7999999523162842, 0.0, 1.5999999046325684, 0.0, 2.200000047683716, 0.0, 2.5999999046325684] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 8.399999618530273, 0.0, 43.0, 0.0, 58.0, 0.0, 8.59999942779541, 0.0, 17.399999618530273, 0.0, 26.599998474121094, 0.0, 13.799999237060547, 0.0, 16.799999237060547] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.7999999523162842, 0.0, 7.799999713897705, 0.0, 0.3999999761581421, 0.0, 0.5999999642372131, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 5.45+/- 1.71 (max: 58.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.725+/- 0.4899 (max: 7.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 3.587+/- 1.948 (max: 29.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 12.04+/- 4.341 (max: 58.0) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 5.772+/- 1.305 (max: 34.35) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.618+/- 0.7191 (max: 9.755) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 5.192+/- 2.033 (max: 29.98) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.5+/- 2.948 (max: 34.35) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.06833+/- 0.0244 (max: 0.77) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04+/- 0.0277 (max: 0.39) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.165+/- 0.06196 (max: 0.77) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [9.0, 0.0, 97.79999542236328, 0.0, 65.0, 0.0, 13.59999942779541, 0.0, 75.19999694824219, 0.0, 17.19999885559082, 0.0, 14.59999942779541, 0.0, 94.19999694824219, 0.0, 6.399999618530273, 0.0, 14.59999942779541, 0.0, 79.4000015258789, 0.0, 14.399999618530273, 0.0, 6.399999618530273, 0.0, 94.0, 0.0, 21.19999885559082, 0.0, 13.59999942779541, 0.0, 58.599998474121094, 0.0, 11.399999618530273, 0.0, 14.0, 0.0, 107.79999542236328, 0.0, 13.0, 0.0, 24.600000381469727, 0.0, 102.39999389648438, 0.0, 9.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [9.0, 0.0, 13.59999942779541, 0.0, 14.59999942779541, 0.0, 14.59999942779541, 0.0, 6.399999618530273, 0.0, 13.59999942779541, 0.0, 14.0, 0.0, 24.600000381469727, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [97.79999542236328, 0.0, 75.19999694824219, 0.0, 94.19999694824219, 0.0, 79.4000015258789, 0.0, 94.0, 0.0, 58.599998474121094, 0.0, 107.79999542236328, 0.0, 102.39999389648438, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [65.0, 0.0, 17.19999885559082, 0.0, 6.399999618530273, 0.0, 14.399999618530273, 0.0, 21.19999885559082, 0.0, 11.399999618530273, 0.0, 13.0, 0.0, 9.0, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 20.36+/- 4.81 (max: 107.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 9.85+/- 4.106 (max: 65.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 6.9+/- 1.997 (max: 24.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 44.34+/- 11.78 (max: 107.8) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.5+/- 2.081 (max: 49.84) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 9.037+/- 2.414 (max: 23.22) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 6.978+/- 1.859 (max: 17.17) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 18.49+/- 5.137 (max: 49.84) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.21+/- 0.04861 (max: 0.98) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.1262+/- 0.05903 (max: 0.95) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.07+/- 0.02566 (max: 0.35) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.4337+/- 0.1138 (max: 0.98) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [0.3999999761581421, 3.1999998092651367, 0.0, 0.19999998807907104, 0.0, 0.0, 1.5999999046325684, 8.800000190734863, 1.399999976158142, 2.799999952316284, 0.0, 0.0, 1.5999999046325684, 5.199999809265137, 0.0, 0.0, 0.0, 0.0, 0.5999999642372131, 1.0, 0.3999999761581421, 1.0, 0.0, 0.0, 1.0, 3.3999998569488525, 0.0, 0.0, 0.0, 0.0, 0.7999999523162842, 2.3999998569488525, 3.799999952316284, 4.799999713897705, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 1.399999976158142, 3.0, 0.0, 0.0, 0.19999998807907104, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [0.3999999761581421, 3.1999998092651367, 1.5999999046325684, 8.800000190734863, 1.5999999046325684, 5.199999809265137, 0.5999999642372131, 1.0, 1.0, 3.3999998569488525, 0.7999999523162842, 2.3999998569488525, 0.3999999761581421, 0.0, 0.19999998807907104, 0.19999998807907104] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.19999998807907104, 1.399999976158142, 2.799999952316284, 0.0, 0.0, 0.3999999761581421, 1.0, 0.0, 0.0, 3.799999952316284, 4.799999713897705, 1.399999976158142, 3.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 1.037+/- 0.2565 (max: 8.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0125+/- 0.0125 (max: 0.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 1.925+/- 0.5819 (max: 8.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 1.175+/- 0.3945 (max: 4.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 2.755+/- 0.4724 (max: 11.77) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.1244+/- 0.1244 (max: 1.99) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 4.941+/- 0.7438 (max: 11.77) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 3.201+/- 0.858 (max: 9.432) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.001875+/- 0.00114 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.00375+/- 0.003146 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.001875+/- 0.00136 (max: 0.02) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 1.7999999523162842, 0.0, 0.3999999761581421, 0.0, 0.0, 0.19999998807907104, 1.7999999523162842, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.3999999761581421, 0.19999998807907104, 1.5999999046325684, 0.19999998807907104, 14.399999618530273, 7.199999809265137, 0.5999999642372131, 0.0, 1.5999999046325684, 0.5999999642372131, 4.599999904632568, 1.7999999523162842, 2.799999952316284, 0.7999999523162842, 0.19999998807907104, 0.0, 3.3999998569488525, 0.5999999642372131, 0.5999999642372131, 0.3999999761581421, 0.19999998807907104, 0.0, 2.200000047683716, 0.0, 0.0, 0.0, 1.0, 0.3999999761581421, 6.599999904632568, 0.7999999523162842, 0.3999999761581421, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.19999998807907104, 0.19999998807907104, 0.0, 1.5999999046325684, 0.19999998807907104, 1.5999999046325684, 0.5999999642372131, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 1.0, 0.3999999761581421] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [1.7999999523162842, 0.0, 1.7999999523162842, 0.0, 0.0, 0.0, 14.399999618530273, 7.199999809265137, 4.599999904632568, 1.7999999523162842, 3.3999998569488525, 0.5999999642372131, 2.200000047683716, 0.0, 6.599999904632568, 0.7999999523162842] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [0.3999999761581421, 0.0, 0.0, 0.0, 0.3999999761581421, 0.19999998807907104, 0.5999999642372131, 0.0, 2.799999952316284, 0.7999999523162842, 0.5999999642372131, 0.3999999761581421, 0.0, 0.0, 0.3999999761581421, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +--------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.208+/- 0.3632 (max: 14.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.4125+/- 0.1727 (max: 2.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.3875+/- 0.136 (max: 1.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 2.825+/- 0.9647 (max: 14.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 2.87+/- 0.4278 (max: 9.82) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 1.93+/- 0.5105 (max: 6.94) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 1.961+/- 0.4815 (max: 5.426) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 4.72+/- 0.9409 (max: 9.82) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0004167+/- 0.0002915 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +--------------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 8.399999618530273, 0.0, 21.799999237060547, 0.0, 1.7999999523162842, 0.0, 8.800000190734863, 0.0, 28.0, 0.0, 1.399999976158142, 0.0, 36.39999771118164, 0.0, 68.19999694824219, 0.0, 13.799999237060547, 0.0, 22.600000381469727, 0.0, 7.799999713897705, 0.0, 7.0, 0.0, 10.59999942779541, 0.0, 24.0, 0.0, 20.0, 0.0, 9.800000190734863, 0.0, 26.599998474121094, 0.0, 1.0, 0.0, 9.800000190734863, 0.0, 12.59999942779541, 0.0, 11.0, 0.0, 15.59999942779541, 0.0, 27.399999618530273, 0.0, 1.5999999046325684] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 8.399999618530273, 0.0, 8.800000190734863, 0.0, 36.39999771118164, 0.0, 22.600000381469727, 0.0, 10.59999942779541, 0.0, 9.800000190734863, 0.0, 9.800000190734863, 0.0, 15.59999942779541] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 21.799999237060547, 0.0, 28.0, 0.0, 68.19999694824219, 0.0, 7.799999713897705, 0.0, 24.0, 0.0, 26.599998474121094, 0.0, 12.59999942779541, 0.0, 27.399999618530273] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 1.7999999523162842, 0.0, 1.399999976158142, 0.0, 13.799999237060547, 0.0, 7.0, 0.0, 20.0, 0.0, 1.0, 0.0, 11.0, 0.0, 1.5999999046325684] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +-------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.25+/- 1.906 (max: 68.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.6+/- 1.53 (max: 20.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 7.625+/- 2.583 (max: 36.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 13.52+/- 4.672 (max: 68.2) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 6.361+/- 1.15 (max: 34.65) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 3.48+/- 0.977 (max: 9.95) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 6.666+/- 1.93 (max: 24.88) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 8.937+/- 2.599 (max: 34.65) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.07937+/- 0.02494 (max: 0.83) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.003125+/- 0.003125 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.06187+/- 0.03528 (max: 0.51) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1731+/- 0.05997 (max: 0.83) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [0.7999999523162842, 0.7999999523162842, 4.799999713897705, 3.0, 2.0, 0.7999999523162842, 3.3999998569488525, 3.0, 4.799999713897705, 2.5999999046325684, 1.5999999046325684, 0.0, 3.1999998092651367, 3.1999998092651367, 3.1999998092651367, 3.1999998092651367, 1.399999976158142, 0.3999999761581421, 2.200000047683716, 2.0, 3.1999998092651367, 2.200000047683716, 1.0, 0.19999998807907104, 1.399999976158142, 1.399999976158142, 7.199999809265137, 4.400000095367432, 0.7999999523162842, 0.3999999761581421, 2.200000047683716, 3.1999998092651367, 1.5999999046325684, 1.399999976158142, 1.399999976158142, 0.3999999761581421, 3.1999998092651367, 3.1999998092651367, 4.799999713897705, 4.599999904632568, 1.1999999284744263, 0.3999999761581421, 4.0, 2.799999952316284, 6.0, 3.799999952316284, 1.1999999284744263, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [0.7999999523162842, 0.7999999523162842, 3.3999998569488525, 3.0, 3.1999998092651367, 3.1999998092651367, 2.200000047683716, 2.0, 1.399999976158142, 1.399999976158142, 2.200000047683716, 3.1999998092651367, 3.1999998092651367, 3.1999998092651367, 4.0, 2.799999952316284] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [4.799999713897705, 3.0, 4.799999713897705, 2.5999999046325684, 3.1999998092651367, 3.1999998092651367, 3.1999998092651367, 2.200000047683716, 7.199999809265137, 4.400000095367432, 1.5999999046325684, 1.399999976158142, 4.799999713897705, 4.599999904632568, 6.0, 3.799999952316284] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [2.0, 0.7999999523162842, 1.5999999046325684, 0.0, 1.399999976158142, 0.3999999761581421, 1.0, 0.19999998807907104, 0.7999999523162842, 0.3999999761581421, 1.399999976158142, 0.3999999761581421, 1.1999999284744263, 0.3999999761581421, 1.1999999284744263, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +---------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 2.375+/- 0.238 (max: 7.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 0.825+/- 0.1504 (max: 2.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 2.5+/- 0.2456 (max: 4.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 3.8+/- 0.3912 (max: 7.2) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 6.011+/- 0.3493 (max: 10.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 3.532+/- 0.4474 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 6.68+/- 0.3846 (max: 9.798) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 7.82+/- 0.3758 (max: 10.4) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.003333+/- 0.001127 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.005+/- 0.002582 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.005+/- 0.002041 (max: 0.02) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.1+/- 0.01459 (max: 0.2) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.1+/- 0.02582 (max: 0.2) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.995+/- 0.1451 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.995+/- 0.2569 (max: 1.99) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 0.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 1.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 5.103 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +---------------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.799999713897705, 5.199999809265137, 17.799999237060547, 17.19999885559082, 23.399999618530273, 26.0, 13.59999942779541, 15.799999237060547, 27.19999885559082, 25.399999618530273, 22.19999885559082, 23.19999885559082, 13.0, 11.59999942779541, 25.599998474121094, 23.0, 22.399999618530273, 23.0, 13.59999942779541, 12.399999618530273, 16.19999885559082, 15.0, 17.19999885559082, 18.399999618530273, 14.59999942779541, 17.19999885559082, 20.399999618530273, 21.0, 21.19999885559082, 23.0, 13.59999942779541, 14.799999237060547, 24.399999618530273, 22.799999237060547, 33.79999923706055, 31.399999618530273, 3.0, 2.799999952316284, 17.399999618530273, 15.0, 27.799999237060547, 29.599998474121094, 10.59999942779541, 12.799999237060547, 12.799999237060547, 12.799999237060547, 19.19999885559082, 19.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.799999713897705, 5.199999809265137, 13.59999942779541, 15.799999237060547, 13.0, 11.59999942779541, 13.59999942779541, 12.399999618530273, 14.59999942779541, 17.19999885559082, 13.59999942779541, 14.799999237060547, 3.0, 2.799999952316284, 10.59999942779541, 12.799999237060547] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [17.799999237060547, 17.19999885559082, 27.19999885559082, 25.399999618530273, 25.599998474121094, 23.0, 16.19999885559082, 15.0, 20.399999618530273, 21.0, 24.399999618530273, 22.799999237060547, 17.399999618530273, 15.0, 12.799999237060547, 12.799999237060547] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [23.399999618530273, 26.0, 22.19999885559082, 23.19999885559082, 22.399999618530273, 23.0, 17.19999885559082, 18.399999618530273, 21.19999885559082, 23.0, 33.79999923706055, 31.399999618530273, 27.799999237060547, 29.599998474121094, 19.19999885559082, 19.799999237060547] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 18.23+/- 1.013 (max: 33.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 23.85+/- 1.182 (max: 33.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.21+/- 1.157 (max: 17.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 19.62+/- 1.182 (max: 27.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.0+/- 0.2966 (max: 16.63) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.37+/- 0.4493 (max: 14.28) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.66+/- 0.6011 (max: 14.66) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.96+/- 0.4105 (max: 16.63) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.1858+/- 0.02108 (max: 0.66) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.265+/- 0.042 (max: 0.66) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0775+/- 0.01315 (max: 0.16) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.215+/- 0.03134 (max: 0.42) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 17.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 12.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.818 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 10.4 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.11 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.02 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------- +Evaluating PAIRED_CNN-S5_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [1.5999999046325684, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 0.5999999642372131, 0.0, 1.7999999523162842, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.19999998807907104, 1.5999999046325684, 0.0, 0.19999998807907104, 0.19999998807907104, 1.1999999284744263, 0.0, 2.3999998569488525, 0.0, 1.0, 0.0, 1.7999999523162842, 0.0, 2.5999999046325684, 0.0, 1.0, 0.0, 1.1999999284744263, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.19999998807907104, 0.3999999761581421, 0.0, 1.0, 0.19999998807907104, 1.0, 0.0, 2.0, 0.0, 3.1999998092651367, 0.0, 2.3999998569488525, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [1.5999999046325684, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.19999998807907104, 1.1999999284744263, 0.0, 1.7999999523162842, 0.0, 1.1999999284744263, 0.0, 0.3999999761581421, 0.0, 2.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [1.1999999284744263, 0.0, 1.7999999523162842, 0.0, 1.5999999046325684, 0.0, 2.3999998569488525, 0.0, 2.5999999046325684, 0.0, 0.3999999761581421, 0.0, 1.0, 0.19999998807907104, 3.1999998092651367, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.1999999284744263, 0.0, 0.5999999642372131, 0.0, 0.19999998807907104, 0.19999998807907104, 1.0, 0.0, 1.0, 0.0, 0.19999998807907104, 0.19999998807907104, 1.0, 0.0, 2.3999998569488525, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.6667+/- 0.1231 (max: 3.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.5+/- 0.1673 (max: 2.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.6+/- 0.1807 (max: 2.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.9+/- 0.2757 (max: 3.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 2.491+/- 0.3565 (max: 7.332) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 2.231+/- 0.5387 (max: 6.499) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 2.391+/- 0.6017 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 2.85+/- 0.7269 (max: 7.332) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-S5_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [9.59999942779541, 6.599999904632568, 15.0, 8.399999618530273, 8.59999942779541, 8.59999942779541, 5.0, 5.400000095367432, 8.0, 9.0, 4.0, 1.0, 5.0, 3.3999998569488525, 11.800000190734863, 11.800000190734863, 10.399999618530273, 6.0, 8.800000190734863, 6.0, 19.19999885559082, 15.59999942779541, 10.800000190734863, 7.799999713897705, 7.799999713897705, 7.599999904632568, 16.799999237060547, 13.399999618530273, 8.800000190734863, 2.3999998569488525, 11.199999809265137, 8.59999942779541, 16.399999618530273, 9.199999809265137, 8.399999618530273, 6.199999809265137, 9.199999809265137, 5.599999904632568, 9.800000190734863, 9.800000190734863, 8.0, 5.599999904632568, 12.0, 7.199999809265137, 13.399999618530273, 13.399999618530273, 9.59999942779541, 4.799999713897705] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [9.59999942779541, 6.599999904632568, 5.0, 5.400000095367432, 5.0, 3.3999998569488525, 8.800000190734863, 6.0, 7.799999713897705, 7.599999904632568, 11.199999809265137, 8.59999942779541, 9.199999809265137, 5.599999904632568, 12.0, 7.199999809265137] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [15.0, 8.399999618530273, 8.0, 9.0, 11.800000190734863, 11.800000190734863, 19.19999885559082, 15.59999942779541, 16.799999237060547, 13.399999618530273, 16.399999618530273, 9.199999809265137, 9.800000190734863, 9.800000190734863, 13.399999618530273, 13.399999618530273] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [8.59999942779541, 8.59999942779541, 4.0, 1.0, 10.399999618530273, 6.0, 10.800000190734863, 7.799999713897705, 8.800000190734863, 2.3999998569488525, 8.399999618530273, 6.199999809265137, 8.0, 5.599999904632568, 9.59999942779541, 4.799999713897705] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 8.979+/- 0.5517 (max: 19.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 6.937+/- 0.7049 (max: 10.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 7.437+/- 0.5939 (max: 12.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 12.56+/- 0.848 (max: 19.2) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.04+/- 0.3655 (max: 17.45) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 9.48+/- 0.4317 (max: 11.48) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.18+/- 0.3669 (max: 12.43) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.46+/- 0.577 (max: 17.45) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.04729+/- 0.008826 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.01187+/- 0.003561 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02187+/- 0.006783 (max: 0.07) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1081+/- 0.01733 (max: 0.22) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.196 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.196 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.513 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 9.837 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-S5_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.3999999761581421, 2.3999998569488525, 0.3999999761581421, 7.0, 0.5999999642372131, 0.5999999642372131, 0.3999999761581421, 5.599999904632568, 0.3999999761581421, 41.599998474121094, 0.3999999761581421, 0.5999999642372131, 0.19999998807907104, 33.599998474121094, 0.19999998807907104, 55.79999923706055, 0.5999999642372131, 11.59999942779541, 0.19999998807907104, 17.19999885559082, 0.3999999761581421, 7.599999904632568, 1.0, 1.0, 0.5999999642372131, 2.200000047683716, 0.19999998807907104, 17.19999885559082, 0.5999999642372131, 2.200000047683716, 0.5999999642372131, 2.799999952316284, 0.19999998807907104, 27.799999237060547, 0.19999998807907104, 0.19999998807907104, 0.3999999761581421, 4.599999904632568, 0.7999999523162842, 15.199999809265137, 0.0, 2.200000047683716, 0.7999999523162842, 6.199999809265137, 0.3999999761581421, 17.799999237060547, 1.1999999284744263, 1.1999999284744263] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.3999999761581421, 2.3999998569488525, 0.3999999761581421, 5.599999904632568, 0.19999998807907104, 33.599998474121094, 0.19999998807907104, 17.19999885559082, 0.5999999642372131, 2.200000047683716, 0.5999999642372131, 2.799999952316284, 0.3999999761581421, 4.599999904632568, 0.7999999523162842, 6.199999809265137] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.3999999761581421, 7.0, 0.3999999761581421, 41.599998474121094, 0.19999998807907104, 55.79999923706055, 0.3999999761581421, 7.599999904632568, 0.19999998807907104, 17.19999885559082, 0.19999998807907104, 27.799999237060547, 0.7999999523162842, 15.199999809265137, 0.3999999761581421, 17.799999237060547] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.5999999642372131, 0.5999999642372131, 0.3999999761581421, 0.5999999642372131, 0.5999999642372131, 11.59999942779541, 1.0, 1.0, 0.5999999642372131, 2.200000047683716, 0.19999998807907104, 0.19999998807907104, 0.0, 2.200000047683716, 1.1999999284744263, 1.1999999284744263] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 6.154+/- 1.689 (max: 55.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 1.512+/- 0.6909 (max: 11.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 4.887+/- 2.193 (max: 33.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 12.06+/- 4.204 (max: 55.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 7.68+/- 1.121 (max: 33.14) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 4.076+/- 0.588 (max: 10.65) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.48+/- 1.785 (max: 28.83) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 11.49+/- 2.54 (max: 33.14) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.07417+/- 0.02551 (max: 0.73) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.00125+/- 0.00125 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.05437+/- 0.03504 (max: 0.5) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1669+/- 0.06274 (max: 0.73) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_CNN-S5_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [11.800000190734863, 8.59999942779541, 57.599998474121094, 56.39999771118164, 68.79999542236328, 63.0, 18.0, 15.0, 58.0, 53.79999923706055, 55.79999923706055, 51.79999923706055, 15.799999237060547, 16.0, 56.39999771118164, 60.0, 58.79999923706055, 55.19999694824219, 17.399999618530273, 19.600000381469727, 62.599998474121094, 59.39999771118164, 47.0, 47.20000076293945, 11.59999942779541, 11.59999942779541, 53.0, 46.39999771118164, 42.0, 46.79999923706055, 20.19999885559082, 16.0, 56.39999771118164, 58.0, 58.79999923706055, 51.79999923706055, 14.799999237060547, 16.399999618530273, 60.39999771118164, 61.79999923706055, 59.79999923706055, 55.599998474121094, 19.799999237060547, 18.600000381469727, 64.0, 58.0, 47.20000076293945, 44.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [11.800000190734863, 8.59999942779541, 18.0, 15.0, 15.799999237060547, 16.0, 17.399999618530273, 19.600000381469727, 11.59999942779541, 11.59999942779541, 20.19999885559082, 16.0, 14.799999237060547, 16.399999618530273, 19.799999237060547, 18.600000381469727] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [57.599998474121094, 56.39999771118164, 58.0, 53.79999923706055, 56.39999771118164, 60.0, 62.599998474121094, 59.39999771118164, 53.0, 46.39999771118164, 56.39999771118164, 58.0, 60.39999771118164, 61.79999923706055, 64.0, 58.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [68.79999542236328, 63.0, 55.79999923706055, 51.79999923706055, 58.79999923706055, 55.19999694824219, 47.0, 47.20000076293945, 42.0, 46.79999923706055, 58.79999923706055, 51.79999923706055, 59.79999923706055, 55.599998474121094, 47.20000076293945, 44.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 42.23+/- 2.848 (max: 68.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 53.35+/- 1.857 (max: 68.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.7+/- 0.8418 (max: 20.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 57.64+/- 1.053 (max: 64.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 24.3+/- 1.199 (max: 35.03) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 29.12+/- 0.5422 (max: 32.41) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.31+/- 0.2789 (max: 15.1) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 30.45+/- 1.004 (max: 35.03) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.5702+/- 0.04507 (max: 0.94) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.7775+/- 0.01811 (max: 0.92) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1431+/- 0.01653 (max: 0.25) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.79+/- 0.01579 (max: 0.94) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 8.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 42.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 8.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 46.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.05 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 25.69 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 11.05 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 20.25 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.65 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.66 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED1 against population in Overcooked-CoordRing6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [3.0, 3.0, 12.799999237060547, 9.0, 11.199999809265137, 12.0, 10.399999618530273, 8.0, 17.600000381469727, 21.19999885559082, 12.399999618530273, 13.799999237060547, 9.0, 11.0, 15.399999618530273, 15.799999237060547, 15.199999809265137, 17.0, 13.199999809265137, 10.399999618530273, 13.59999942779541, 12.59999942779541, 11.800000190734863, 11.399999618530273, 12.0, 8.800000190734863, 8.59999942779541, 8.0, 15.399999618530273, 17.0, 10.399999618530273, 9.800000190734863, 16.600000381469727, 19.600000381469727, 24.799999237060547, 26.399999618530273, 2.0, 2.200000047683716, 13.0, 15.59999942779541, 24.600000381469727, 31.399999618530273, 11.800000190734863, 14.399999618530273, 10.0, 10.59999942779541, 20.19999885559082, 23.600000381469727] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [3.0, 3.0, 10.399999618530273, 8.0, 9.0, 11.0, 13.199999809265137, 10.399999618530273, 12.0, 8.800000190734863, 10.399999618530273, 9.800000190734863, 2.0, 2.200000047683716, 11.800000190734863, 14.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [12.799999237060547, 9.0, 17.600000381469727, 21.19999885559082, 15.399999618530273, 15.799999237060547, 13.59999942779541, 12.59999942779541, 8.59999942779541, 8.0, 16.600000381469727, 19.600000381469727, 13.0, 15.59999942779541, 10.0, 10.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [11.199999809265137, 12.0, 12.399999618530273, 13.799999237060547, 15.199999809265137, 17.0, 11.800000190734863, 11.399999618530273, 15.399999618530273, 17.0, 24.799999237060547, 26.399999618530273, 24.600000381469727, 31.399999618530273, 20.19999885559082, 23.600000381469727] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 13.49+/- 0.8864 (max: 31.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 18.01+/- 1.589 (max: 31.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 8.712+/- 1.002 (max: 14.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 13.75+/- 0.9831 (max: 21.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.45+/- 0.3211 (max: 16.73) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.08+/- 0.5 (max: 14.77) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 10.94+/- 0.6586 (max: 14.55) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.34+/- 0.4533 (max: 16.73) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.09854+/- 0.01743 (max: 0.56) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.1444+/- 0.04506 (max: 0.56) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.05375+/- 0.01052 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0975+/- 0.02101 (max: 0.3) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 11.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 8.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.417 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 9.798 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-S5_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.19999998807907104, 0.0, 0.0, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.0, 0.0, 0.3999999761581421, 0.19999998807907104, 0.5999999642372131, 0.3999999761581421, 0.5999999642372131, 0.19999998807907104, 1.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.19999998807907104, 0.0, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.0, 0.0, 0.3999999761581421, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0, 0.19999998807907104, 0.5999999642372131, 0.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 0.3999999761581421, 0.0, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.0, 0.0, 0.19999998807907104, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.19999998807907104, 0.0, 0.3999999761581421, 0.5999999642372131, 0.3999999761581421, 0.19999998807907104, 0.19999998807907104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.19999998807907104, 0.19999998807907104, 0.0, 0.3999999761581421, 0.19999998807907104, 1.0, 0.0, 0.19999998807907104, 0.0, 0.3999999761581421, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0, 0.19999998807907104] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.175+/- 0.03073 (max: 1.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.1875+/- 0.06447 (max: 1.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.1625+/- 0.04171 (max: 0.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.175+/- 0.05439 (max: 0.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 1.342+/- 0.1859 (max: 4.359) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 1.369+/- 0.3444 (max: 4.359) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 1.383+/- 0.2926 (max: 3.412) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 1.274+/- 0.3466 (max: 3.412) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-S5_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [14.59999942779541, 11.199999809265137, 33.0, 18.399999618530273, 17.19999885559082, 12.799999237060547, 6.0, 6.399999618530273, 13.199999809265137, 9.0, 2.799999952316284, 0.3999999761581421, 1.5999999046325684, 0.5999999642372131, 5.0, 3.799999952316284, 4.400000095367432, 3.0, 9.59999942779541, 6.799999713897705, 26.399999618530273, 21.19999885559082, 11.199999809265137, 6.399999618530273, 12.59999942779541, 10.800000190734863, 23.600000381469727, 17.399999618530273, 5.799999713897705, 3.0, 14.199999809265137, 12.59999942779541, 23.19999885559082, 15.59999942779541, 11.800000190734863, 7.199999809265137, 10.0, 7.799999713897705, 20.600000381469727, 13.799999237060547, 11.800000190734863, 6.799999713897705, 16.19999885559082, 12.199999809265137, 38.39999771118164, 28.799999237060547, 11.0, 6.799999713897705] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [14.59999942779541, 11.199999809265137, 6.0, 6.399999618530273, 1.5999999046325684, 0.5999999642372131, 9.59999942779541, 6.799999713897705, 12.59999942779541, 10.800000190734863, 14.199999809265137, 12.59999942779541, 10.0, 7.799999713897705, 16.19999885559082, 12.199999809265137] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [33.0, 18.399999618530273, 13.199999809265137, 9.0, 5.0, 3.799999952316284, 26.399999618530273, 21.19999885559082, 23.600000381469727, 17.399999618530273, 23.19999885559082, 15.59999942779541, 20.600000381469727, 13.799999237060547, 38.39999771118164, 28.799999237060547] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [17.19999885559082, 12.799999237060547, 2.799999952316284, 0.3999999761581421, 4.400000095367432, 3.0, 11.199999809265137, 6.399999618530273, 5.799999713897705, 3.0, 11.800000190734863, 7.199999809265137, 11.800000190734863, 6.799999713897705, 11.0, 6.799999713897705] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 12.23+/- 1.203 (max: 38.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 7.65+/- 1.141 (max: 17.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 9.575+/- 1.114 (max: 16.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 19.46+/- 2.388 (max: 38.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.78+/- 0.5775 (max: 21.97) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 9.269+/- 0.5789 (max: 12.17) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.74+/- 0.704 (max: 13.52) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 15.35+/- 0.9815 (max: 21.97) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.1144+/- 0.02353 (max: 0.62) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.02+/- 0.007528 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.05062+/- 0.01142 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.2725+/- 0.04985 (max: 0.62) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.6 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 3.8 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 2.8 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 2.8 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 3.412 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 7.846 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [1.7999999523162842, 2.5999999046325684, 0.5999999642372131, 11.0, 1.399999976158142, 3.5999999046325684, 1.7999999523162842, 6.599999904632568, 1.0, 45.599998474121094, 1.399999976158142, 3.5999999046325684, 1.399999976158142, 41.0, 1.5999999046325684, 57.79999923706055, 1.7999999523162842, 13.59999942779541, 1.0, 17.799999237060547, 1.7999999523162842, 10.0, 2.3999998569488525, 6.399999618530273, 1.1999999284744263, 3.1999998092651367, 1.5999999046325684, 19.0, 1.5999999046325684, 3.799999952316284, 0.19999998807907104, 2.200000047683716, 1.7999999523162842, 33.0, 2.0, 4.400000095367432, 1.0, 2.5999999046325684, 1.7999999523162842, 18.0, 1.1999999284744263, 7.599999904632568, 0.5999999642372131, 5.400000095367432, 0.5999999642372131, 20.600000381469727, 1.399999976158142, 2.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [1.7999999523162842, 2.5999999046325684, 1.7999999523162842, 6.599999904632568, 1.399999976158142, 41.0, 1.0, 17.799999237060547, 1.1999999284744263, 3.1999998092651367, 0.19999998807907104, 2.200000047683716, 1.0, 2.5999999046325684, 0.5999999642372131, 5.400000095367432] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.5999999642372131, 11.0, 1.0, 45.599998474121094, 1.5999999046325684, 57.79999923706055, 1.7999999523162842, 10.0, 1.5999999046325684, 19.0, 1.7999999523162842, 33.0, 1.7999999523162842, 18.0, 0.5999999642372131, 20.600000381469727] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.399999976158142, 3.5999999046325684, 1.399999976158142, 3.5999999046325684, 1.7999999523162842, 13.59999942779541, 2.3999998569488525, 6.399999618530273, 1.5999999046325684, 3.799999952316284, 2.0, 4.400000095367432, 1.1999999284744263, 7.599999904632568, 1.399999976158142, 2.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 7.8+/- 1.813 (max: 57.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.637+/- 0.8129 (max: 13.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 5.65+/- 2.583 (max: 41.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 14.11+/- 4.411 (max: 57.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 9.557+/- 0.9995 (max: 30.45) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.627+/- 0.6797 (max: 14.11) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.045+/- 1.619 (max: 28.34) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 13.0+/- 2.262 (max: 30.45) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.08854+/- 0.02773 (max: 0.78) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.01812+/- 0.00765 (max: 0.12) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.05688+/- 0.03978 (max: 0.61) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1906+/- 0.06703 (max: 0.78) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 1.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.6 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 4.75 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 3.412 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-S5_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.399999618530273, 7.0, 65.19999694824219, 63.19999694824219, 76.4000015258789, 74.79999542236328, 13.59999942779541, 12.0, 63.39999771118164, 58.39999771118164, 82.0, 72.79999542236328, 17.0, 15.799999237060547, 77.5999984741211, 73.4000015258789, 71.5999984741211, 64.19999694824219, 15.799999237060547, 15.0, 70.0, 68.5999984741211, 53.39999771118164, 52.599998474121094, 8.199999809265137, 8.59999942779541, 66.4000015258789, 62.599998474121094, 67.0, 73.19999694824219, 16.799999237060547, 16.0, 65.5999984741211, 69.4000015258789, 83.19999694824219, 88.5999984741211, 16.19999885559082, 13.199999809265137, 70.19999694824219, 69.79999542236328, 73.19999694824219, 75.5999984741211, 17.19999885559082, 14.799999237060547, 75.4000015258789, 71.0, 73.79999542236328, 75.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.399999618530273, 7.0, 13.59999942779541, 12.0, 17.0, 15.799999237060547, 15.799999237060547, 15.0, 8.199999809265137, 8.59999942779541, 16.799999237060547, 16.0, 16.19999885559082, 13.199999809265137, 17.19999885559082, 14.799999237060547] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [65.19999694824219, 63.19999694824219, 63.39999771118164, 58.39999771118164, 77.5999984741211, 73.4000015258789, 70.0, 68.5999984741211, 66.4000015258789, 62.599998474121094, 65.5999984741211, 69.4000015258789, 70.19999694824219, 69.79999542236328, 75.4000015258789, 71.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [76.4000015258789, 74.79999542236328, 82.0, 72.79999542236328, 71.5999984741211, 64.19999694824219, 53.39999771118164, 52.599998474121094, 67.0, 73.19999694824219, 83.19999694824219, 88.5999984741211, 73.19999694824219, 75.5999984741211, 73.79999542236328, 75.79999542236328] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 51.33+/- 4.021 (max: 88.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 72.39+/- 2.388 (max: 88.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 13.47+/- 0.8841 (max: 17.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 68.14+/- 1.26 (max: 77.6) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 25.21+/- 1.261 (max: 35.47) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 31.18+/- 0.5428 (max: 34.14) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.42+/- 0.4522 (max: 15.59) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 31.04+/- 0.7315 (max: 35.47) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.621+/- 0.05182 (max: 0.96) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.8762+/- 0.01604 (max: 0.96) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.125+/- 0.01725 (max: 0.22) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.8619+/- 0.008427 (max: 0.93) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 52.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 58.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.66 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 24.96 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.73 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.77 | +-------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [1.7999999523162842, 3.0, 5.599999904632568, 6.599999904632568, 1.7999999523162842, 6.599999904632568, 4.799999713897705, 8.59999942779541, 9.399999618530273, 10.59999942779541, 3.1999998092651367, 7.199999809265137, 5.199999809265137, 6.199999809265137, 7.399999618530273, 7.599999904632568, 4.599999904632568, 5.199999809265137, 5.199999809265137, 3.5999999046325684, 4.599999904632568, 7.399999618530273, 3.1999998092651367, 4.599999904632568, 7.399999618530273, 8.199999809265137, 3.5999999046325684, 6.0, 1.7999999523162842, 3.5999999046325684, 5.599999904632568, 6.0, 10.199999809265137, 13.199999809265137, 14.59999942779541, 16.600000381469727, 1.1999999284744263, 1.0, 8.800000190734863, 8.800000190734863, 14.59999942779541, 15.0, 5.599999904632568, 8.199999809265137, 6.0, 7.199999809265137, 11.399999618530273, 13.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [1.7999999523162842, 3.0, 4.799999713897705, 8.59999942779541, 5.199999809265137, 6.199999809265137, 5.199999809265137, 3.5999999046325684, 7.399999618530273, 8.199999809265137, 5.599999904632568, 6.0, 1.1999999284744263, 1.0, 5.599999904632568, 8.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [5.599999904632568, 6.599999904632568, 9.399999618530273, 10.59999942779541, 7.399999618530273, 7.599999904632568, 4.599999904632568, 7.399999618530273, 3.5999999046325684, 6.0, 10.199999809265137, 13.199999809265137, 8.800000190734863, 8.800000190734863, 6.0, 7.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [1.7999999523162842, 6.599999904632568, 3.1999998092651367, 7.199999809265137, 4.599999904632568, 5.199999809265137, 3.1999998092651367, 4.599999904632568, 1.7999999523162842, 3.5999999046325684, 14.59999942779541, 16.600000381469727, 14.59999942779541, 15.0, 11.399999618530273, 13.399999618530273] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 6.917+/- 0.5506 (max: 16.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 7.962+/- 1.333 (max: 16.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 5.1+/- 0.6088 (max: 8.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 7.687+/- 0.6072 (max: 13.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 9.568+/- 0.2986 (max: 13.03) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.48+/- 0.5955 (max: 12.92) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 8.839+/- 0.5658 (max: 11.75) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 10.39+/- 0.2874 (max: 13.03) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.02437+/- 0.004334 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0325+/- 0.01074 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.01562+/- 0.00418 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.025+/- 0.005845 (max: 0.1) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 1.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 1.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 1.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 3.6 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 4.359 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 5.724 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 4.359 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 8.188 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [1.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.3999999761581421, 0.7999999523162842, 0.3999999761581421, 0.3999999761581421, 0.19999998807907104, 0.5999999642372131, 0.0, 0.7999999523162842, 0.19999998807907104, 0.0, 0.0, 0.7999999523162842, 0.3999999761581421, 0.5999999642372131, 0.0, 0.7999999523162842, 0.3999999761581421, 1.399999976158142, 0.0, 0.7999999523162842, 0.0, 0.3999999761581421, 0.0, 0.5999999642372131, 0.0, 0.3999999761581421, 0.19999998807907104, 0.0, 0.5999999642372131, 0.0, 0.0, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 1.0, 1.0, 0.0, 2.0, 0.19999998807907104, 2.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [1.0, 0.0, 1.0, 0.3999999761581421, 0.5999999642372131, 0.0, 0.7999999523162842, 0.3999999761581421, 1.399999976158142, 0.0, 0.5999999642372131, 0.0, 0.0, 0.0, 1.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [1.0, 0.0, 0.7999999523162842, 0.3999999761581421, 0.7999999523162842, 0.19999998807907104, 0.5999999642372131, 0.0, 0.7999999523162842, 0.0, 0.3999999761581421, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 2.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.3999999761581421, 0.19999998807907104, 0.0, 0.0, 0.7999999523162842, 0.3999999761581421, 0.3999999761581421, 0.0, 0.0, 0.5999999642372131, 0.19999998807907104, 1.0, 2.0, 0.0] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.4375+/- 0.07236 (max: 2.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.375+/- 0.134 (max: 2.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.45+/- 0.119 (max: 1.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.4875+/- 0.1291 (max: 2.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 2.201+/- 0.2716 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 1.879+/- 0.487 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 2.158+/- 0.5118 (max: 5.103) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 2.567+/- 0.422 (max: 6.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-S5_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [14.59999942779541, 10.59999942779541, 35.0, 22.600000381469727, 28.599998474121094, 27.0, 13.59999942779541, 10.800000190734863, 20.399999618530273, 13.799999237060547, 6.599999904632568, 0.7999999523162842, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 1.5999999046325684, 2.5999999046325684, 10.0, 10.800000190734863, 27.799999237060547, 22.19999885559082, 15.199999809265137, 5.599999904632568, 15.0, 12.0, 32.79999923706055, 25.599998474121094, 8.0, 3.3999998569488525, 15.0, 15.0, 34.0, 23.799999237060547, 15.199999809265137, 5.599999904632568, 16.19999885559082, 8.800000190734863, 43.0, 19.0, 9.399999618530273, 3.1999998092651367, 16.399999618530273, 9.800000190734863, 44.599998474121094, 23.19999885559082, 9.800000190734863, 5.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [14.59999942779541, 10.59999942779541, 13.59999942779541, 10.800000190734863, 0.3999999761581421, 0.0, 10.0, 10.800000190734863, 15.0, 12.0, 15.0, 15.0, 16.19999885559082, 8.800000190734863, 16.399999618530273, 9.800000190734863] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [35.0, 22.600000381469727, 20.399999618530273, 13.799999237060547, 0.3999999761581421, 0.0, 27.799999237060547, 22.19999885559082, 32.79999923706055, 25.599998474121094, 34.0, 23.799999237060547, 43.0, 19.0, 44.599998474121094, 23.19999885559082] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [28.599998474121094, 27.0, 6.599999904632568, 0.7999999523162842, 1.5999999046325684, 2.5999999046325684, 15.199999809265137, 5.599999904632568, 8.0, 3.3999998569488525, 15.199999809265137, 5.599999904632568, 9.399999618530273, 3.1999998092651367, 9.800000190734863, 5.0] +---------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 14.89+/- 1.623 (max: 44.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 9.225+/- 2.1 (max: 28.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 11.19+/- 1.234 (max: 16.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 24.26+/- 3.149 (max: 44.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.82+/- 0.7045 (max: 19.93) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 9.154+/- 0.6016 (max: 12.45) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.99+/- 0.979 (max: 15.19) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 15.32+/- 1.418 (max: 19.93) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.1773+/- 0.03269 (max: 0.81) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.06625+/- 0.0372 (max: 0.46) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.07187+/- 0.0117 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.3937+/- 0.06194 (max: 0.81) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +---------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.3999999761581421, 1.399999976158142, 0.0, 5.400000095367432, 0.0, 0.0, 0.0, 3.1999998092651367, 0.0, 36.20000076293945, 0.19999998807907104, 0.0, 0.19999998807907104, 31.19999885559082, 0.3999999761581421, 48.20000076293945, 0.3999999761581421, 7.599999904632568, 0.3999999761581421, 11.399999618530273, 0.3999999761581421, 5.400000095367432, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 1.0, 0.0, 13.399999618530273, 0.19999998807907104, 0.7999999523162842, 0.0, 1.1999999284744263, 0.3999999761581421, 24.19999885559082, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 1.399999976158142, 0.0, 13.399999618530273, 0.19999998807907104, 1.7999999523162842, 0.5999999642372131, 3.799999952316284, 0.0, 18.399999618530273, 0.3999999761581421, 0.19999998807907104] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.3999999761581421, 1.399999976158142, 0.0, 3.1999998092651367, 0.19999998807907104, 31.19999885559082, 0.3999999761581421, 11.399999618530273, 0.19999998807907104, 1.0, 0.0, 1.1999999284744263, 0.19999998807907104, 1.399999976158142, 0.5999999642372131, 3.799999952316284] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 5.400000095367432, 0.0, 36.20000076293945, 0.3999999761581421, 48.20000076293945, 0.3999999761581421, 5.400000095367432, 0.0, 13.399999618530273, 0.3999999761581421, 24.19999885559082, 0.0, 13.399999618530273, 0.0, 18.399999618530273] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.0, 0.19999998807907104, 0.0, 0.3999999761581421, 7.599999904632568, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 1.7999999523162842, 0.3999999761581421, 0.19999998807907104] +-------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 4.904+/- 1.492 (max: 48.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.8125+/- 0.4653 (max: 7.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 3.537+/- 1.974 (max: 31.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 10.36+/- 3.684 (max: 48.2) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 6.068+/- 1.091 (max: 31.38) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 2.697+/- 0.6093 (max: 9.708) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 5.737+/- 1.62 (max: 25.51) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 9.771+/- 2.553 (max: 31.38) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.05958+/- 0.02192 (max: 0.67) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.03937+/- 0.02983 (max: 0.47) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1387+/- 0.0543 (max: 0.67) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [9.800000190734863, 6.799999713897705, 58.599998474121094, 54.39999771118164, 67.4000015258789, 67.5999984741211, 9.399999618530273, 10.800000190734863, 44.599998474121094, 35.39999771118164, 51.19999694824219, 44.0, 12.799999237060547, 11.0, 54.0, 48.0, 57.79999923706055, 60.0, 13.799999237060547, 13.199999809265137, 55.0, 51.39999771118164, 39.599998474121094, 40.0, 9.800000190734863, 8.800000190734863, 37.79999923706055, 35.79999923706055, 35.39999771118164, 35.20000076293945, 16.600000381469727, 15.199999809265137, 51.39999771118164, 43.0, 50.0, 48.39999771118164, 10.399999618530273, 9.0, 50.0, 46.20000076293945, 43.79999923706055, 42.0, 16.19999885559082, 14.799999237060547, 53.19999694824219, 50.20000076293945, 46.39999771118164, 54.79999923706055] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [9.800000190734863, 6.799999713897705, 9.399999618530273, 10.800000190734863, 12.799999237060547, 11.0, 13.799999237060547, 13.199999809265137, 9.800000190734863, 8.800000190734863, 16.600000381469727, 15.199999809265137, 10.399999618530273, 9.0, 16.19999885559082, 14.799999237060547] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [58.599998474121094, 54.39999771118164, 44.599998474121094, 35.39999771118164, 54.0, 48.0, 55.0, 51.39999771118164, 37.79999923706055, 35.79999923706055, 51.39999771118164, 43.0, 50.0, 46.20000076293945, 53.19999694824219, 50.20000076293945] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [67.4000015258789, 67.5999984741211, 51.19999694824219, 44.0, 57.79999923706055, 60.0, 39.599998474121094, 40.0, 35.39999771118164, 35.20000076293945, 50.0, 48.39999771118164, 43.79999923706055, 42.0, 46.39999771118164, 54.79999923706055] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 36.27+/- 2.734 (max: 67.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 48.97+/- 2.562 (max: 67.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 11.77+/- 0.7335 (max: 16.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 48.06+/- 1.766 (max: 58.6) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 24.84+/- 1.26 (max: 37.12) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 30.03+/- 0.6819 (max: 33.59) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.28+/- 0.4056 (max: 15.63) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 31.22+/- 0.9342 (max: 37.12) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.4892+/- 0.04298 (max: 0.92) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.6844+/- 0.03431 (max: 0.92) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.09937+/- 0.0128 (max: 0.19) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.6837+/- 0.02876 (max: 0.91) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 6.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 35.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 6.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 35.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.28 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 25.1 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.28 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 23.2 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.52 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.49 | +------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [2.799999952316284, 4.0, 14.0, 14.0, 11.800000190734863, 12.0, 12.399999618530273, 13.399999618530273, 22.399999618530273, 24.19999885559082, 9.59999942779541, 13.199999809265137, 9.59999942779541, 12.59999942779541, 12.799999237060547, 16.0, 10.59999942779541, 12.799999237060547, 12.199999809265137, 9.800000190734863, 16.19999885559082, 15.199999809265137, 11.0, 13.59999942779541, 11.399999618530273, 11.59999942779541, 10.0, 11.399999618530273, 12.0, 13.399999618530273, 8.0, 8.800000190734863, 18.19999885559082, 19.600000381469727, 23.19999885559082, 25.0, 3.5999999046325684, 2.5999999046325684, 11.59999942779541, 10.399999618530273, 19.399999618530273, 22.19999885559082, 7.199999809265137, 4.799999713897705, 9.199999809265137, 8.399999618530273, 7.799999713897705, 9.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [2.799999952316284, 4.0, 12.399999618530273, 13.399999618530273, 9.59999942779541, 12.59999942779541, 12.199999809265137, 9.800000190734863, 11.399999618530273, 11.59999942779541, 8.0, 8.800000190734863, 3.5999999046325684, 2.5999999046325684, 7.199999809265137, 4.799999713897705] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [14.0, 14.0, 22.399999618530273, 24.19999885559082, 12.799999237060547, 16.0, 16.19999885559082, 15.199999809265137, 10.0, 11.399999618530273, 18.19999885559082, 19.600000381469727, 11.59999942779541, 10.399999618530273, 9.199999809265137, 8.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [11.800000190734863, 12.0, 9.59999942779541, 13.199999809265137, 10.59999942779541, 12.799999237060547, 11.0, 13.59999942779541, 12.0, 13.399999618530273, 23.19999885559082, 25.0, 19.399999618530273, 22.19999885559082, 7.799999713897705, 9.399999618530273] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 12.4+/- 0.771 (max: 25.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 14.19+/- 1.317 (max: 25.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 8.425+/- 0.9527 (max: 13.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 14.6+/- 1.164 (max: 24.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.35+/- 0.2857 (max: 15.59) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.94+/- 0.4266 (max: 15.59) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 10.85+/- 0.5935 (max: 14.16) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.26+/- 0.3876 (max: 15.26) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.07625+/- 0.01365 (max: 0.38) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.075+/- 0.02958 (max: 0.38) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.045+/- 0.009531 (max: 0.12) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.1087+/- 0.02551 (max: 0.33) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 7.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 8.4 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.726 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.145 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.726 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 10.27 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.01 | +------------------------------------------------------------------------------------------------ +Evaluating ACCEL_CNN-S5_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [2.0, 0.0, 4.400000095367432, 0.0, 1.0, 0.0, 4.799999713897705, 0.0, 6.199999809265137, 0.0, 4.400000095367432, 0.0, 5.400000095367432, 0.0, 7.399999618530273, 0.0, 6.199999809265137, 0.0, 4.199999809265137, 0.0, 3.799999952316284, 0.0, 1.5999999046325684, 0.0, 1.1999999284744263, 0.0, 2.0, 0.0, 2.3999998569488525, 0.0, 4.0, 0.0, 5.799999713897705, 0.0, 2.0, 0.0, 3.5999999046325684, 0.0, 2.799999952316284, 0.0, 4.599999904632568, 0.0, 2.200000047683716, 0.0, 1.7999999523162842, 0.0, 4.199999809265137, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [2.0, 0.0, 4.799999713897705, 0.0, 5.400000095367432, 0.0, 4.199999809265137, 0.0, 1.1999999284744263, 0.0, 4.0, 0.0, 3.5999999046325684, 0.0, 2.200000047683716, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [4.400000095367432, 0.0, 6.199999809265137, 0.0, 7.399999618530273, 0.0, 3.799999952316284, 0.0, 2.0, 0.0, 5.799999713897705, 0.0, 2.799999952316284, 0.0, 1.7999999523162842, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.0, 0.0, 4.400000095367432, 0.0, 6.199999809265137, 0.0, 1.5999999046325684, 0.0, 2.3999998569488525, 0.0, 2.0, 0.0, 4.599999904632568, 0.0, 4.199999809265137, 0.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.833+/- 0.3216 (max: 7.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.65+/- 0.5258 (max: 6.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.712+/- 0.5089 (max: 5.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.137+/- 0.6539 (max: 7.4) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.782+/- 0.5769 (max: 10.45) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.587+/- 0.9751 (max: 9.673) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.75+/- 1.012 (max: 9.739) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.008+/- 1.071 (max: 10.45) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.001875+/- 0.0007682 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0025+/- 0.001708 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.001875+/- 0.00136 (max: 0.02) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [0.3999999761581421, 0.3999999761581421, 1.5999999046325684, 0.7999999523162842, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 0.0, 0.0, 0.5999999642372131, 0.19999998807907104, 2.200000047683716, 1.399999976158142, 0.7999999523162842, 0.3999999761581421, 1.0, 0.0, 5.599999904632568, 2.3999998569488525, 0.7999999523162842, 0.0, 0.3999999761581421, 0.3999999761581421, 4.199999809265137, 2.0, 0.19999998807907104, 0.0, 1.399999976158142, 0.0, 1.399999976158142, 0.3999999761581421, 1.0, 0.0, 0.19999998807907104, 0.0, 0.7999999523162842, 0.3999999761581421, 0.0, 0.0, 0.5999999642372131, 0.0, 1.399999976158142, 0.7999999523162842, 0.19999998807907104, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [0.3999999761581421, 0.3999999761581421, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 1.0, 0.0, 0.3999999761581421, 0.3999999761581421, 1.399999976158142, 0.0, 0.19999998807907104, 0.0, 0.5999999642372131, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [1.5999999046325684, 0.7999999523162842, 0.5999999642372131, 0.19999998807907104, 2.200000047683716, 1.399999976158142, 5.599999904632568, 2.3999998569488525, 4.199999809265137, 2.0, 1.399999976158142, 0.3999999761581421, 0.7999999523162842, 0.3999999761581421, 1.399999976158142, 0.7999999523162842] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [0.19999998807907104, 0.19999998807907104, 0.0, 0.0, 0.7999999523162842, 0.3999999761581421, 0.7999999523162842, 0.0, 0.19999998807907104, 0.0, 1.0, 0.0, 0.0, 0.0, 0.19999998807907104, 0.0] +--------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.75+/- 0.1558 (max: 5.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.2375+/- 0.0841 (max: 1.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.375+/- 0.09639 (max: 1.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 1.637+/- 0.3639 (max: 5.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 2.87+/- 0.328 (max: 8.98) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 1.435+/- 0.4106 (max: 4.359) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 2.215+/- 0.3926 (max: 5.103) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 4.961+/- 0.481 (max: 8.98) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +--------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 6.0, 0.0, 15.799999237060547, 0.3999999761581421, 11.0, 0.0, 8.59999942779541, 0.0, 44.39999771118164, 0.19999998807907104, 2.5999999046325684, 0.0, 40.599998474121094, 0.0, 63.599998474121094, 0.19999998807907104, 10.199999809265137, 0.0, 18.0, 0.0, 15.0, 0.0, 4.599999904632568, 0.0, 6.799999713897705, 0.0, 21.799999237060547, 0.0, 7.799999713897705, 0.0, 6.199999809265137, 0.0, 30.799999237060547, 0.0, 1.7999999523162842, 0.19999998807907104, 7.799999713897705, 0.0, 17.19999885559082, 0.19999998807907104, 7.199999809265137, 0.0, 7.199999809265137, 0.0, 25.19999885559082, 0.0, 6.399999618530273] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 6.0, 0.0, 8.59999942779541, 0.0, 40.599998474121094, 0.0, 18.0, 0.0, 6.799999713897705, 0.0, 6.199999809265137, 0.19999998807907104, 7.799999713897705, 0.0, 7.199999809265137] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 15.799999237060547, 0.0, 44.39999771118164, 0.0, 63.599998474121094, 0.0, 15.0, 0.0, 21.799999237060547, 0.0, 30.799999237060547, 0.0, 17.19999885559082, 0.0, 25.19999885559082] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.3999999761581421, 11.0, 0.19999998807907104, 2.5999999046325684, 0.19999998807907104, 10.199999809265137, 0.0, 4.599999904632568, 0.0, 7.799999713897705, 0.0, 1.7999999523162842, 0.19999998807907104, 7.199999809265137, 0.0, 6.399999618530273] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.079+/- 1.93 (max: 63.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.287+/- 0.9941 (max: 11.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.337+/- 2.612 (max: 40.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 14.61+/- 4.756 (max: 63.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 7.927+/- 1.275 (max: 31.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 5.898+/- 1.337 (max: 14.25) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.021+/- 2.041 (max: 27.78) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.86+/- 2.903 (max: 31.8) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.09708+/- 0.02841 (max: 0.83) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.02062+/- 0.007983 (max: 0.09) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.06812+/- 0.04038 (max: 0.63) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2025+/- 0.06859 (max: 0.83) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating ACCEL_CNN-S5_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [7.799999713897705, 7.599999904632568, 58.79999923706055, 60.19999694824219, 86.5999984741211, 91.0, 14.59999942779541, 15.799999237060547, 68.79999542236328, 67.4000015258789, 84.5999984741211, 88.4000015258789, 16.799999237060547, 17.0, 89.0, 89.0, 86.5999984741211, 90.19999694824219, 13.199999809265137, 13.399999618530273, 69.19999694824219, 73.5999984741211, 69.0, 68.79999542236328, 7.799999713897705, 8.800000190734863, 65.0, 63.19999694824219, 76.0, 79.0, 14.199999809265137, 14.59999942779541, 73.5999984741211, 75.5999984741211, 83.0, 85.79999542236328, 17.600000381469727, 15.199999809265137, 60.39999771118164, 66.4000015258789, 86.19999694824219, 85.4000015258789, 16.0, 17.399999618530273, 72.5999984741211, 71.0, 80.5999984741211, 83.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [7.799999713897705, 7.599999904632568, 14.59999942779541, 15.799999237060547, 16.799999237060547, 17.0, 13.199999809265137, 13.399999618530273, 7.799999713897705, 8.800000190734863, 14.199999809265137, 14.59999942779541, 17.600000381469727, 15.199999809265137, 16.0, 17.399999618530273] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [58.79999923706055, 60.19999694824219, 68.79999542236328, 67.4000015258789, 89.0, 89.0, 69.19999694824219, 73.5999984741211, 65.0, 63.19999694824219, 73.5999984741211, 75.5999984741211, 60.39999771118164, 66.4000015258789, 72.5999984741211, 71.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [86.5999984741211, 91.0, 84.5999984741211, 88.4000015258789, 86.5999984741211, 90.19999694824219, 69.0, 68.79999542236328, 76.0, 79.0, 83.0, 85.79999542236328, 86.19999694824219, 85.4000015258789, 80.5999984741211, 83.0] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 55.54+/- 4.491 (max: 91.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 82.76+/- 1.664 (max: 91.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 13.61+/- 0.8997 (max: 17.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 70.24+/- 2.229 (max: 89.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 19.7+/- 0.7193 (max: 27.94) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 22.86+/- 0.5345 (max: 25.83) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.36+/- 0.3594 (max: 14.93) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 22.87+/- 0.654 (max: 27.94) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6775+/- 0.05719 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9762+/- 0.005468 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1269+/- 0.0158 (max: 0.2) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9294+/- 0.008731 (max: 0.98) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 68.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 58.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.87 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 19.26 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.87 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 18.29 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.93 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.84 | +-------------------------------------------------------------------------------------------------- diff --git a/src/run_results_txt/eval_xpid_all_softmoe_out.txt b/src/run_results_txt/eval_xpid_all_softmoe_out.txt new file mode 100644 index 0000000..8b93e44 --- /dev/null +++ b/src/run_results_txt/eval_xpid_all_softmoe_out.txt @@ -0,0 +1,2280 @@ +Evaluating DR_SoftMoE_SEED1 against population in Overcooked-CoordRing6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.199999809265137, 3.5999999046325684, 3.799999952316284, 3.0, 1.399999976158142, 3.0, 3.3999998569488525, 4.799999713897705, 5.599999904632568, 9.399999618530273, 2.5999999046325684, 3.0, 6.799999713897705, 7.799999713897705, 4.0, 5.0, 2.5999999046325684, 2.799999952316284, 13.799999237060547, 7.199999809265137, 11.0, 10.199999809265137, 3.0, 0.7999999523162842, 3.1999998092651367, 2.799999952316284, 1.0, 3.3999998569488525, 2.3999998569488525, 2.200000047683716, 14.399999618530273, 14.399999618530273, 21.19999885559082, 21.0, 5.599999904632568, 4.799999713897705, 3.1999998092651367, 2.3999998569488525, 11.399999618530273, 16.600000381469727, 13.799999237060547, 21.0, 9.0, 13.59999942779541, 8.199999809265137, 9.199999809265137, 4.599999904632568, 8.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.199999809265137, 3.5999999046325684, 3.3999998569488525, 4.799999713897705, 6.799999713897705, 7.799999713897705, 13.799999237060547, 7.199999809265137, 3.1999998092651367, 2.799999952316284, 14.399999618530273, 14.399999618530273, 3.1999998092651367, 2.3999998569488525, 9.0, 13.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [3.799999952316284, 3.0, 5.599999904632568, 9.399999618530273, 4.0, 5.0, 11.0, 10.199999809265137, 1.0, 3.3999998569488525, 21.19999885559082, 21.0, 11.399999618530273, 16.600000381469727, 8.199999809265137, 9.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [1.399999976158142, 3.0, 2.5999999046325684, 3.0, 2.5999999046325684, 2.799999952316284, 3.0, 0.7999999523162842, 2.3999998569488525, 2.200000047683716, 5.599999904632568, 4.799999713897705, 13.799999237060547, 21.0, 4.599999904632568, 8.59999942779541] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 7.1+/- 0.7939 (max: 21.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 5.137+/- 1.323 (max: 21.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 7.162+/- 1.134 (max: 14.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 9.0+/- 1.546 (max: 21.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 10.38+/- 0.4978 (max: 18.63) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 9.118+/- 0.947 (max: 18.63) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 10.27+/- 0.6658 (max: 14.39) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.74+/- 0.8718 (max: 17.52) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.06292+/- 0.01227 (max: 0.34) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.04687+/- 0.02196 (max: 0.34) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.04687+/- 0.01251 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.095+/- 0.02596 (max: 0.33) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 0.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 0.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 1.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 5.196 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.01 | +------------------------------------------------------------------------------------------------ +Evaluating DR_SoftMoE_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.3999999761581421, 0.3999999761581421, 0.3999999761581421, 1.7999999523162842, 0.5999999642372131, 2.0, 0.3999999761581421, 0.7999999523162842, 0.3999999761581421, 0.5999999642372131, 0.0, 1.0, 0.19999998807907104, 3.3999998569488525, 0.0, 5.400000095367432, 0.19999998807907104, 7.199999809265137, 0.3999999761581421, 0.3999999761581421, 0.3999999761581421, 1.0, 0.19999998807907104, 1.7999999523162842, 0.0, 2.0, 0.3999999761581421, 1.7999999523162842, 1.399999976158142, 1.1999999284744263, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 0.3999999761581421, 0.3999999761581421, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 0.3999999761581421, 1.1999999284744263, 0.3999999761581421, 0.3999999761581421, 1.1999999284744263, 1.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.3999999761581421, 0.3999999761581421, 0.3999999761581421, 0.7999999523162842, 0.19999998807907104, 3.3999998569488525, 0.3999999761581421, 0.3999999761581421, 0.0, 2.0, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 0.3999999761581421, 0.3999999761581421, 1.1999999284744263] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.3999999761581421, 1.7999999523162842, 0.3999999761581421, 0.5999999642372131, 0.0, 5.400000095367432, 0.3999999761581421, 1.0, 0.3999999761581421, 1.7999999523162842, 0.19999998807907104, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 0.3999999761581421, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.5999999642372131, 2.0, 0.0, 1.0, 0.19999998807907104, 7.199999809265137, 0.19999998807907104, 1.7999999523162842, 1.399999976158142, 1.1999999284744263, 0.19999998807907104, 0.5999999642372131, 0.7999999523162842, 0.19999998807907104, 1.1999999284744263, 1.0] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.9333+/- 0.1917 (max: 7.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.225+/- 0.4254 (max: 7.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.7+/- 0.216 (max: 3.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.875+/- 0.3291 (max: 5.4) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.491+/- 0.2874 (max: 9.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.959+/- 0.5544 (max: 9.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.159+/- 0.4342 (max: 7.513) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.354+/- 0.5072 (max: 8.879) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating DR_SoftMoE_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [6.399999618530273, 7.0, 10.59999942779541, 9.59999942779541, 6.399999618530273, 2.0, 5.0, 4.599999904632568, 6.199999809265137, 5.599999904632568, 1.0, 0.0, 0.0, 0.0, 1.0, 0.5999999642372131, 1.1999999284744263, 1.1999999284744263, 5.199999809265137, 3.0, 13.59999942779541, 10.59999942779541, 4.0, 2.3999998569488525, 7.399999618530273, 4.199999809265137, 9.800000190734863, 9.59999942779541, 1.5999999046325684, 1.399999976158142, 8.399999618530273, 3.3999998569488525, 10.199999809265137, 6.199999809265137, 5.599999904632568, 4.400000095367432, 6.399999618530273, 4.400000095367432, 7.399999618530273, 7.399999618530273, 7.599999904632568, 2.3999998569488525, 7.399999618530273, 5.799999713897705, 10.199999809265137, 10.0, 6.799999713897705, 1.1999999284744263] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [6.399999618530273, 7.0, 5.0, 4.599999904632568, 0.0, 0.0, 5.199999809265137, 3.0, 7.399999618530273, 4.199999809265137, 8.399999618530273, 3.3999998569488525, 6.399999618530273, 4.400000095367432, 7.399999618530273, 5.799999713897705] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [10.59999942779541, 9.59999942779541, 6.199999809265137, 5.599999904632568, 1.0, 0.5999999642372131, 13.59999942779541, 10.59999942779541, 9.800000190734863, 9.59999942779541, 10.199999809265137, 6.199999809265137, 7.399999618530273, 7.399999618530273, 10.199999809265137, 10.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [6.399999618530273, 2.0, 1.0, 0.0, 1.1999999284744263, 1.1999999284744263, 4.0, 2.3999998569488525, 1.5999999046325684, 1.399999976158142, 5.599999904632568, 4.400000095367432, 7.599999904632568, 2.3999998569488525, 6.799999713897705, 1.1999999284744263] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 5.342+/- 0.4983 (max: 13.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 3.075+/- 0.5977 (max: 7.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 4.912+/- 0.6099 (max: 8.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.037+/- 0.8747 (max: 13.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 8.861+/- 0.5576 (max: 16.37) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 6.63+/- 0.7181 (max: 11.59) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 8.289+/- 0.8692 (max: 12.71) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 11.67+/- 0.8794 (max: 16.37) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.02833+/- 0.005557 (max: 0.13) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00625+/- 0.003521 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.01187+/- 0.003191 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.06687+/- 0.01087 (max: 0.13) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 0.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 3.412 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_SoftMoE_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [10.59999942779541, 15.399999618530273, 9.0, 24.600000381469727, 16.19999885559082, 28.599998474121094, 10.59999942779541, 20.19999885559082, 14.199999809265137, 54.599998474121094, 11.59999942779541, 15.399999618530273, 16.19999885559082, 39.79999923706055, 14.199999809265137, 68.5999984741211, 15.199999809265137, 22.799999237060547, 12.0, 31.0, 13.59999942779541, 25.599998474121094, 13.799999237060547, 11.0, 11.800000190734863, 11.399999618530273, 10.399999618530273, 30.799999237060547, 13.799999237060547, 18.600000381469727, 12.199999809265137, 17.0, 12.0, 43.79999923706055, 13.199999809265137, 13.0, 10.800000190734863, 15.59999942779541, 15.799999237060547, 33.79999923706055, 12.59999942779541, 16.19999885559082, 9.399999618530273, 20.600000381469727, 8.59999942779541, 38.39999771118164, 18.19999885559082, 18.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [10.59999942779541, 15.399999618530273, 10.59999942779541, 20.19999885559082, 16.19999885559082, 39.79999923706055, 12.0, 31.0, 11.800000190734863, 11.399999618530273, 12.199999809265137, 17.0, 10.800000190734863, 15.59999942779541, 9.399999618530273, 20.600000381469727] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [9.0, 24.600000381469727, 14.199999809265137, 54.599998474121094, 14.199999809265137, 68.5999984741211, 13.59999942779541, 25.599998474121094, 10.399999618530273, 30.799999237060547, 12.0, 43.79999923706055, 15.799999237060547, 33.79999923706055, 8.59999942779541, 38.39999771118164] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [16.19999885559082, 28.599998474121094, 11.59999942779541, 15.399999618530273, 15.199999809265137, 22.799999237060547, 13.799999237060547, 11.0, 13.799999237060547, 18.600000381469727, 13.199999809265137, 13.0, 12.59999942779541, 16.19999885559082, 18.19999885559082, 18.799999237060547] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 19.62+/- 1.775 (max: 68.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 16.19+/- 1.132 (max: 28.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 16.54+/- 2.067 (max: 39.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 26.12+/- 4.452 (max: 68.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 19.33+/- 0.8651 (max: 35.24) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 19.19+/- 1.115 (max: 27.83) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 17.0+/- 1.186 (max: 30.0) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 21.81+/- 1.901 (max: 35.24) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.2521+/- 0.02769 (max: 0.83) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.1987+/- 0.01832 (max: 0.38) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.2056+/- 0.03832 (max: 0.61) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.3519+/- 0.06639 (max: 0.83) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 11.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 9.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 8.6 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 12.08 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 13.92 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 12.68 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.08 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.04 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.11 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.08 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.04 | +----------------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [10.800000190734863, 8.399999618530273, 67.4000015258789, 51.19999694824219, 96.5999984741211, 88.5999984741211, 16.19999885559082, 15.59999942779541, 76.79999542236328, 61.19999694824219, 91.5999984741211, 85.4000015258789, 18.19999885559082, 15.59999942779541, 89.4000015258789, 83.5999984741211, 99.5999984741211, 91.79999542236328, 19.0, 16.600000381469727, 86.0, 73.19999694824219, 86.4000015258789, 79.19999694824219, 9.399999618530273, 10.0, 74.19999694824219, 62.39999771118164, 95.4000015258789, 88.0, 17.799999237060547, 14.59999942779541, 82.0, 67.5999984741211, 98.5999984741211, 90.19999694824219, 17.0, 15.399999618530273, 80.79999542236328, 67.5999984741211, 94.5999984741211, 91.0, 21.600000381469727, 16.799999237060547, 85.79999542236328, 72.4000015258789, 92.4000015258789, 90.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [10.800000190734863, 8.399999618530273, 16.19999885559082, 15.59999942779541, 18.19999885559082, 15.59999942779541, 19.0, 16.600000381469727, 9.399999618530273, 10.0, 17.799999237060547, 14.59999942779541, 17.0, 15.399999618530273, 21.600000381469727, 16.799999237060547] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [67.4000015258789, 51.19999694824219, 76.79999542236328, 61.19999694824219, 89.4000015258789, 83.5999984741211, 86.0, 73.19999694824219, 74.19999694824219, 62.39999771118164, 82.0, 67.5999984741211, 80.79999542236328, 67.5999984741211, 85.79999542236328, 72.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [96.5999984741211, 88.5999984741211, 91.5999984741211, 85.4000015258789, 99.5999984741211, 91.79999542236328, 86.4000015258789, 79.19999694824219, 95.4000015258789, 88.0, 98.5999984741211, 90.19999694824219, 94.5999984741211, 91.0, 92.4000015258789, 90.79999542236328] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 60.1+/- 4.853 (max: 99.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 91.26+/- 1.298 (max: 99.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.19+/- 0.9285 (max: 21.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 73.85+/- 2.644 (max: 89.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 22.59+/- 0.9647 (max: 31.86) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.57+/- 0.8156 (max: 31.86) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 14.03+/- 0.3917 (max: 16.27) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.18+/- 0.7295 (max: 31.21) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6775+/- 0.05446 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.96+/- 0.00677 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1562+/- 0.01839 (max: 0.3) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9162+/- 0.01615 (max: 0.98) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 8.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 79.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 8.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 51.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.02 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 20.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 11.02 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 20.91 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.91 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.76 | +------------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.799999713897705, 6.0, 17.19999885559082, 16.399999618530273, 20.799999237060547, 20.799999237060547, 17.600000381469727, 15.399999618530273, 33.599998474121094, 34.599998474121094, 17.0, 20.600000381469727, 16.19999885559082, 12.799999237060547, 22.0, 25.799999237060547, 18.19999885559082, 20.600000381469727, 14.399999618530273, 10.800000190734863, 19.0, 20.600000381469727, 17.799999237060547, 19.19999885559082, 18.600000381469727, 15.59999942779541, 23.399999618530273, 25.19999885559082, 19.399999618530273, 19.799999237060547, 13.59999942779541, 12.199999809265137, 29.399999618530273, 24.799999237060547, 36.599998474121094, 38.20000076293945, 5.0, 6.0, 18.600000381469727, 20.0, 27.399999618530273, 29.399999618530273, 8.399999618530273, 10.0, 15.199999809265137, 15.399999618530273, 17.799999237060547, 19.600000381469727] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.799999713897705, 6.0, 17.600000381469727, 15.399999618530273, 16.19999885559082, 12.799999237060547, 14.399999618530273, 10.800000190734863, 18.600000381469727, 15.59999942779541, 13.59999942779541, 12.199999809265137, 5.0, 6.0, 8.399999618530273, 10.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [17.19999885559082, 16.399999618530273, 33.599998474121094, 34.599998474121094, 22.0, 25.799999237060547, 19.0, 20.600000381469727, 23.399999618530273, 25.19999885559082, 29.399999618530273, 24.799999237060547, 18.600000381469727, 20.0, 15.199999809265137, 15.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [20.799999237060547, 20.799999237060547, 17.0, 20.600000381469727, 18.19999885559082, 20.600000381469727, 17.799999237060547, 19.19999885559082, 19.399999618530273, 19.799999237060547, 36.599998474121094, 38.20000076293945, 27.399999618530273, 29.399999618530273, 17.799999237060547, 19.600000381469727] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 19.0+/- 1.113 (max: 38.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 22.7+/- 1.656 (max: 38.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.71+/- 1.147 (max: 18.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 22.57+/- 1.511 (max: 34.6) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.18+/- 0.2692 (max: 16.76) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.08+/- 0.3451 (max: 14.45) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.18+/- 0.5189 (max: 15.82) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.29+/- 0.3601 (max: 16.76) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.2125+/- 0.02894 (max: 0.81) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2487+/- 0.06103 (max: 0.81) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.08875+/- 0.017 (max: 0.24) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.3+/- 0.04734 (max: 0.69) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 4.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 17.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 4.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 15.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 8.542 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 8.716 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 8.542 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.7 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.06 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.08 | +----------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [2.799999952316284, 0.0, 3.3999998569488525, 0.0, 2.3999998569488525, 0.3999999761581421, 2.3999998569488525, 0.0, 2.799999952316284, 0.0, 2.200000047683716, 0.0, 2.5999999046325684, 0.0, 3.3999998569488525, 0.0, 2.0, 0.19999998807907104, 3.1999998092651367, 0.0, 2.799999952316284, 0.0, 1.7999999523162842, 0.0, 2.3999998569488525, 0.19999998807907104, 3.5999999046325684, 0.3999999761581421, 2.3999998569488525, 0.0, 2.3999998569488525, 0.0, 3.1999998092651367, 0.0, 2.3999998569488525, 0.0, 1.399999976158142, 0.0, 2.200000047683716, 0.0, 2.200000047683716, 0.0, 2.0, 0.0, 3.1999998092651367, 0.19999998807907104, 3.5999999046325684, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [2.799999952316284, 0.0, 2.3999998569488525, 0.0, 2.5999999046325684, 0.0, 3.1999998092651367, 0.0, 2.3999998569488525, 0.19999998807907104, 2.3999998569488525, 0.0, 1.399999976158142, 0.0, 2.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [3.3999998569488525, 0.0, 2.799999952316284, 0.0, 3.3999998569488525, 0.0, 2.799999952316284, 0.0, 3.5999999046325684, 0.3999999761581421, 3.1999998092651367, 0.0, 2.200000047683716, 0.0, 3.1999998092651367, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [2.3999998569488525, 0.3999999761581421, 2.200000047683716, 0.0, 2.0, 0.19999998807907104, 1.7999999523162842, 0.0, 2.3999998569488525, 0.0, 2.3999998569488525, 0.0, 2.200000047683716, 0.0, 3.5999999046325684, 0.0] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.337+/- 0.1963 (max: 3.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.225+/- 0.3119 (max: 3.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.212+/- 0.3201 (max: 3.2) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.575+/- 0.3958 (max: 3.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.585+/- 0.468 (max: 7.684) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.513+/- 0.7831 (max: 7.684) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.349+/- 0.8175 (max: 7.332) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.894+/- 0.8759 (max: 7.684) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating DR_SoftMoE_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [1.7999999523162842, 1.5999999046325684, 3.3999998569488525, 1.5999999046325684, 1.399999976158142, 0.5999999642372131, 1.0, 1.1999999284744263, 2.3999998569488525, 1.399999976158142, 0.19999998807907104, 0.0, 3.3999998569488525, 1.0, 10.399999618530273, 5.0, 6.0, 1.7999999523162842, 1.399999976158142, 0.7999999523162842, 10.399999618530273, 8.59999942779541, 2.3999998569488525, 0.0, 2.3999998569488525, 1.7999999523162842, 9.0, 4.599999904632568, 4.199999809265137, 0.5999999642372131, 3.0, 0.0, 5.400000095367432, 1.399999976158142, 1.7999999523162842, 0.3999999761581421, 1.399999976158142, 1.0, 4.0, 1.7999999523162842, 0.19999998807907104, 0.19999998807907104, 2.5999999046325684, 0.5999999642372131, 4.199999809265137, 2.200000047683716, 0.7999999523162842, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [1.7999999523162842, 1.5999999046325684, 1.0, 1.1999999284744263, 3.3999998569488525, 1.0, 1.399999976158142, 0.7999999523162842, 2.3999998569488525, 1.7999999523162842, 3.0, 0.0, 1.399999976158142, 1.0, 2.5999999046325684, 0.5999999642372131] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [3.3999998569488525, 1.5999999046325684, 2.3999998569488525, 1.399999976158142, 10.399999618530273, 5.0, 10.399999618530273, 8.59999942779541, 9.0, 4.599999904632568, 5.400000095367432, 1.399999976158142, 4.0, 1.7999999523162842, 4.199999809265137, 2.200000047683716] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [1.399999976158142, 0.5999999642372131, 0.19999998807907104, 0.0, 6.0, 1.7999999523162842, 2.3999998569488525, 0.0, 4.199999809265137, 0.5999999642372131, 1.7999999523162842, 0.3999999761581421, 0.19999998807907104, 0.19999998807907104, 0.7999999523162842, 0.0] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 2.529+/- 0.3802 (max: 10.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.287+/- 0.4231 (max: 6.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.562+/- 0.2275 (max: 3.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 4.737+/- 0.7985 (max: 10.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.618+/- 0.421 (max: 11.79) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 3.794+/- 0.7308 (max: 10.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.007+/- 0.4442 (max: 7.513) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.053+/- 0.541 (max: 11.79) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.003125+/- 0.001269 (max: 0.05) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00125+/- 0.00125 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.008125+/- 0.003319 (max: 0.05) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 1.4 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 5.103 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 19.19999885559082, 0.0, 37.599998474121094, 0.0, 31.399999618530273, 0.0, 18.0, 0.0, 56.19999694824219, 0.0, 33.599998474121094, 0.0, 56.19999694824219, 0.0, 79.79999542236328, 0.0, 41.79999923706055, 0.0, 37.39999771118164, 0.0, 38.0, 0.0, 37.39999771118164, 0.0, 18.19999885559082, 0.0, 43.39999771118164, 0.0, 37.20000076293945, 0.0, 14.799999237060547, 0.0, 48.20000076293945, 0.0, 32.599998474121094, 0.0, 17.19999885559082, 0.0, 54.79999923706055, 0.0, 37.599998474121094, 0.0, 21.19999885559082, 0.0, 50.599998474121094, 0.0, 31.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 19.19999885559082, 0.0, 18.0, 0.0, 56.19999694824219, 0.0, 37.39999771118164, 0.0, 18.19999885559082, 0.0, 14.799999237060547, 0.0, 17.19999885559082, 0.0, 21.19999885559082] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 37.599998474121094, 0.0, 56.19999694824219, 0.0, 79.79999542236328, 0.0, 38.0, 0.0, 43.39999771118164, 0.0, 48.20000076293945, 0.0, 54.79999923706055, 0.0, 50.599998474121094] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 31.399999618530273, 0.0, 33.599998474121094, 0.0, 41.79999923706055, 0.0, 37.39999771118164, 0.0, 37.20000076293945, 0.0, 32.599998474121094, 0.0, 37.599998474121094, 0.0, 31.799999237060547] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 18.63+/- 3.135 (max: 79.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 17.71+/- 4.616 (max: 41.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 12.64+/- 4.077 (max: 56.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 25.54+/- 6.988 (max: 79.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 11.13+/- 1.665 (max: 29.61) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 11.17+/- 2.889 (max: 23.61) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 9.704+/- 2.614 (max: 28.38) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.52+/- 3.26 (max: 29.61) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.2821+/- 0.04722 (max: 0.96) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.29+/- 0.076 (max: 0.69) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.1737+/- 0.05973 (max: 0.81) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.3825+/- 0.1008 (max: 0.96) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.399999618530273, 7.0, 69.0, 65.0, 102.19999694824219, 99.79999542236328, 14.399999618530273, 14.0, 72.4000015258789, 68.79999542236328, 107.0, 109.0, 18.19999885559082, 18.600000381469727, 90.5999984741211, 91.0, 107.19999694824219, 108.5999984741211, 16.0, 17.600000381469727, 82.79999542236328, 85.79999542236328, 86.79999542236328, 86.19999694824219, 8.399999618530273, 7.799999713897705, 69.4000015258789, 69.0, 95.4000015258789, 94.19999694824219, 20.0, 20.19999885559082, 91.0, 80.5999984741211, 103.79999542236328, 102.5999984741211, 16.0, 16.799999237060547, 75.4000015258789, 75.0, 102.0, 103.19999694824219, 19.600000381469727, 15.59999942779541, 83.19999694824219, 82.0, 103.5999984741211, 103.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.399999618530273, 7.0, 14.399999618530273, 14.0, 18.19999885559082, 18.600000381469727, 16.0, 17.600000381469727, 8.399999618530273, 7.799999713897705, 20.0, 20.19999885559082, 16.0, 16.799999237060547, 19.600000381469727, 15.59999942779541] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [69.0, 65.0, 72.4000015258789, 68.79999542236328, 90.5999984741211, 91.0, 82.79999542236328, 85.79999542236328, 69.4000015258789, 69.0, 91.0, 80.5999984741211, 75.4000015258789, 75.0, 83.19999694824219, 82.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [102.19999694824219, 99.79999542236328, 107.0, 109.0, 107.19999694824219, 108.5999984741211, 86.79999542236328, 86.19999694824219, 95.4000015258789, 94.19999694824219, 103.79999542236328, 102.5999984741211, 102.0, 103.19999694824219, 103.5999984741211, 103.0] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 64.67+/- 5.399 (max: 109.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 100.9+/- 1.737 (max: 109.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 14.91+/- 1.142 (max: 20.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 78.19+/- 2.205 (max: 91.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 19.09+/- 0.7209 (max: 28.33) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 20.97+/- 0.6191 (max: 25.88) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.14+/- 0.341 (max: 15.56) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 23.15+/- 0.8303 (max: 28.33) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6927+/- 0.05804 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9937+/- 0.002213 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1337+/- 0.01765 (max: 0.24) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9506+/- 0.006675 (max: 0.99) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 86.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 65.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 16.42 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.34 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 17.52 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.97 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9 | +-------------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [14.199999809265137, 11.399999618530273, 34.599998474121094, 30.399999618530273, 25.399999618530273, 27.0, 25.799999237060547, 23.799999237060547, 39.20000076293945, 37.79999923706055, 28.19999885559082, 29.799999237060547, 27.799999237060547, 24.19999885559082, 40.79999923706055, 36.79999923706055, 21.0, 26.0, 28.399999618530273, 29.19999885559082, 30.599998474121094, 34.79999923706055, 25.0, 24.799999237060547, 24.799999237060547, 24.0, 33.79999923706055, 29.399999618530273, 27.399999618530273, 27.799999237060547, 23.799999237060547, 25.0, 34.39999771118164, 38.39999771118164, 38.79999923706055, 38.39999771118164, 8.800000190734863, 9.800000190734863, 27.599998474121094, 28.19999885559082, 25.399999618530273, 32.79999923706055, 22.19999885559082, 23.19999885559082, 27.399999618530273, 24.399999618530273, 24.600000381469727, 24.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [14.199999809265137, 11.399999618530273, 25.799999237060547, 23.799999237060547, 27.799999237060547, 24.19999885559082, 28.399999618530273, 29.19999885559082, 24.799999237060547, 24.0, 23.799999237060547, 25.0, 8.800000190734863, 9.800000190734863, 22.19999885559082, 23.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [34.599998474121094, 30.399999618530273, 39.20000076293945, 37.79999923706055, 40.79999923706055, 36.79999923706055, 30.599998474121094, 34.79999923706055, 33.79999923706055, 29.399999618530273, 34.39999771118164, 38.39999771118164, 27.599998474121094, 28.19999885559082, 27.399999618530273, 24.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [25.399999618530273, 27.0, 28.19999885559082, 29.799999237060547, 21.0, 26.0, 25.0, 24.799999237060547, 27.399999618530273, 27.799999237060547, 38.79999923706055, 38.39999771118164, 25.399999618530273, 32.79999923706055, 24.600000381469727, 24.799999237060547] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 27.55+/- 1.037 (max: 40.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 27.95+/- 1.225 (max: 38.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 21.65+/- 1.669 (max: 29.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 33.04+/- 1.22 (max: 40.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 17.07+/- 0.334 (max: 22.03) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 18.27+/- 0.5251 (max: 21.69) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 15.29+/- 0.4934 (max: 17.64) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 17.65+/- 0.4483 (max: 22.03) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.4375+/- 0.02383 (max: 0.73) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.45+/- 0.02714 (max: 0.67) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.2987+/- 0.03773 (max: 0.5) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.5637+/- 0.02863 (max: 0.73) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 8.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 21.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 8.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 24.4 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.07 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 14.94 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.07 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 14.6 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.24 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.03 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.34 | +----------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.0, 1.0, 0.19999998807907104, 1.399999976158142, 0.0, 1.399999976158142, 0.19999998807907104, 1.1999999284744263, 0.0, 0.5999999642372131, 0.5999999642372131, 2.200000047683716, 1.0, 1.5999999046325684, 1.5999999046325684, 2.0, 0.19999998807907104, 0.7999999523162842, 0.0, 1.0, 0.3999999761581421, 1.1999999284744263, 0.0, 1.0, 0.3999999761581421, 2.200000047683716, 0.3999999761581421, 0.7999999523162842, 0.19999998807907104, 2.799999952316284, 0.0, 3.0, 0.5999999642372131, 3.0, 0.0, 3.5999999046325684, 0.5999999642372131, 1.1999999284744263, 0.19999998807907104, 2.200000047683716, 0.0, 0.7999999523162842, 0.19999998807907104, 0.3999999761581421, 0.19999998807907104, 0.7999999523162842, 0.5999999642372131, 4.799999713897705] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 1.0, 0.19999998807907104, 1.1999999284744263, 1.0, 1.5999999046325684, 0.0, 1.0, 0.3999999761581421, 2.200000047683716, 0.0, 3.0, 0.5999999642372131, 1.1999999284744263, 0.19999998807907104, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.19999998807907104, 1.399999976158142, 0.0, 0.5999999642372131, 1.5999999046325684, 2.0, 0.3999999761581421, 1.1999999284744263, 0.3999999761581421, 0.7999999523162842, 0.5999999642372131, 3.0, 0.19999998807907104, 2.200000047683716, 0.19999998807907104, 0.7999999523162842] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 1.399999976158142, 0.5999999642372131, 2.200000047683716, 0.19999998807907104, 0.7999999523162842, 0.0, 1.0, 0.19999998807907104, 2.799999952316284, 0.0, 3.5999999046325684, 0.0, 0.7999999523162842, 0.5999999642372131, 4.799999713897705] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.012+/- 0.1544 (max: 4.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.187+/- 0.3603 (max: 4.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.875+/- 0.212 (max: 3.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.975+/- 0.2144 (max: 3.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.61+/- 0.3338 (max: 8.998) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.624+/- 0.7192 (max: 8.998) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.4+/- 0.5479 (max: 7.141) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.807+/- 0.4738 (max: 7.141) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0002083+/- 0.0002083 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating DR_SoftMoE_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [14.799999237060547, 14.59999942779541, 25.599998474121094, 18.19999885559082, 11.199999809265137, 10.0, 10.0, 8.399999618530273, 15.399999618530273, 11.399999618530273, 3.5999999046325684, 2.0, 13.0, 12.799999237060547, 16.0, 14.199999809265137, 14.799999237060547, 19.799999237060547, 10.0, 7.599999904632568, 20.799999237060547, 17.19999885559082, 8.199999809265137, 7.799999713897705, 11.399999618530273, 7.599999904632568, 24.19999885559082, 17.600000381469727, 4.199999809265137, 2.799999952316284, 13.399999618530273, 13.59999942779541, 19.399999618530273, 14.199999809265137, 5.799999713897705, 4.799999713897705, 8.800000190734863, 6.199999809265137, 22.799999237060547, 12.59999942779541, 9.399999618530273, 7.799999713897705, 15.59999942779541, 10.399999618530273, 33.39999771118164, 19.399999618530273, 10.399999618530273, 8.199999809265137] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [14.799999237060547, 14.59999942779541, 10.0, 8.399999618530273, 13.0, 12.799999237060547, 10.0, 7.599999904632568, 11.399999618530273, 7.599999904632568, 13.399999618530273, 13.59999942779541, 8.800000190734863, 6.199999809265137, 15.59999942779541, 10.399999618530273] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [25.599998474121094, 18.19999885559082, 15.399999618530273, 11.399999618530273, 16.0, 14.199999809265137, 20.799999237060547, 17.19999885559082, 24.19999885559082, 17.600000381469727, 19.399999618530273, 14.199999809265137, 22.799999237060547, 12.59999942779541, 33.39999771118164, 19.399999618530273] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [11.199999809265137, 10.0, 3.5999999046325684, 2.0, 14.799999237060547, 19.799999237060547, 8.199999809265137, 7.799999713897705, 4.199999809265137, 2.799999952316284, 5.799999713897705, 4.799999713897705, 9.399999618530273, 7.799999713897705, 10.399999618530273, 8.199999809265137] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 12.74+/- 0.9171 (max: 33.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 8.175+/- 1.153 (max: 19.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 11.14+/- 0.7311 (max: 15.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 18.9+/- 1.397 (max: 33.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 12.3+/- 0.4529 (max: 21.92) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 10.07+/- 0.531 (max: 13.71) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 11.43+/- 0.3169 (max: 13.52) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 15.4+/- 0.7245 (max: 21.92) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.1054+/- 0.01784 (max: 0.53) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.035+/- 0.01354 (max: 0.21) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.05062+/- 0.01105 (max: 0.13) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.2306+/- 0.03332 (max: 0.53) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 2.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 2.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 6.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 11.4 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.708 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.58 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.06 | +----------------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [7.599999904632568, 13.399999618530273, 0.5999999642372131, 23.399999618530273, 19.600000381469727, 18.0, 10.399999618530273, 14.799999237060547, 7.799999713897705, 47.599998474121094, 10.0, 12.399999618530273, 15.59999942779541, 45.0, 11.0, 58.19999694824219, 16.600000381469727, 26.399999618530273, 4.0, 25.799999237060547, 3.799999952316284, 17.799999237060547, 18.600000381469727, 13.199999809265137, 8.199999809265137, 11.399999618530273, 8.399999618530273, 28.19999885559082, 4.400000095367432, 18.799999237060547, 9.59999942779541, 14.199999809265137, 4.599999904632568, 36.20000076293945, 16.600000381469727, 9.800000190734863, 7.399999618530273, 11.59999942779541, 4.599999904632568, 25.399999618530273, 20.0, 17.799999237060547, 8.59999942779541, 10.800000190734863, 1.1999999284744263, 35.39999771118164, 15.399999618530273, 12.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [7.599999904632568, 13.399999618530273, 10.399999618530273, 14.799999237060547, 15.59999942779541, 45.0, 4.0, 25.799999237060547, 8.199999809265137, 11.399999618530273, 9.59999942779541, 14.199999809265137, 7.399999618530273, 11.59999942779541, 8.59999942779541, 10.800000190734863] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.5999999642372131, 23.399999618530273, 7.799999713897705, 47.599998474121094, 11.0, 58.19999694824219, 3.799999952316284, 17.799999237060547, 8.399999618530273, 28.19999885559082, 4.599999904632568, 36.20000076293945, 4.599999904632568, 25.399999618530273, 1.1999999284744263, 35.39999771118164] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [19.600000381469727, 18.0, 10.0, 12.399999618530273, 16.600000381469727, 26.399999618530273, 18.600000381469727, 13.199999809265137, 4.400000095367432, 18.799999237060547, 16.600000381469727, 9.800000190734863, 20.0, 17.799999237060547, 15.399999618530273, 12.799999237060547] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 16.31+/- 1.731 (max: 58.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 15.65+/- 1.288 (max: 26.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 13.65+/- 2.419 (max: 45.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 19.64+/- 4.408 (max: 58.2) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 14.57+/- 0.7185 (max: 25.98) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 15.13+/- 0.6393 (max: 19.77) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 13.78+/- 1.065 (max: 25.98) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 14.8+/- 1.801 (max: 25.78) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1917+/- 0.0299 (max: 0.83) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.1862+/- 0.02541 (max: 0.42) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.1306+/- 0.04228 (max: 0.67) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2581+/- 0.07367 (max: 0.83) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 4.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 4.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.6 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 3.412 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 9.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.485 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 3.412 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating DR_SoftMoE_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [24.600000381469727, 20.0, 81.79999542236328, 71.79999542236328, 97.19999694824219, 91.19999694824219, 28.799999237060547, 29.399999618530273, 83.0, 74.5999984741211, 91.0, 87.4000015258789, 31.599998474121094, 27.599998474121094, 103.5999984741211, 95.19999694824219, 96.0, 95.79999542236328, 29.399999618530273, 33.0, 93.79999542236328, 88.19999694824219, 87.19999694824219, 83.4000015258789, 21.600000381469727, 18.799999237060547, 82.19999694824219, 80.0, 104.0, 105.39999389648438, 33.599998474121094, 28.799999237060547, 95.5999984741211, 97.19999694824219, 94.19999694824219, 84.5999984741211, 31.599998474121094, 34.0, 90.79999542236328, 88.5999984741211, 103.79999542236328, 101.19999694824219, 33.39999771118164, 34.79999923706055, 89.79999542236328, 93.0, 92.5999984741211, 93.5999984741211] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [24.600000381469727, 20.0, 28.799999237060547, 29.399999618530273, 31.599998474121094, 27.599998474121094, 29.399999618530273, 33.0, 21.600000381469727, 18.799999237060547, 33.599998474121094, 28.799999237060547, 31.599998474121094, 34.0, 33.39999771118164, 34.79999923706055] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [81.79999542236328, 71.79999542236328, 83.0, 74.5999984741211, 103.5999984741211, 95.19999694824219, 93.79999542236328, 88.19999694824219, 82.19999694824219, 80.0, 95.5999984741211, 97.19999694824219, 90.79999542236328, 88.5999984741211, 89.79999542236328, 93.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [97.19999694824219, 91.19999694824219, 91.0, 87.4000015258789, 96.0, 95.79999542236328, 87.19999694824219, 83.4000015258789, 104.0, 105.39999389648438, 94.19999694824219, 84.5999984741211, 103.79999542236328, 101.19999694824219, 92.5999984741211, 93.5999984741211] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 70.39+/- 4.417 (max: 105.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 94.29+/- 1.712 (max: 105.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 28.81+/- 1.274 (max: 34.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 88.07+/- 2.151 (max: 103.6) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 27.71+/- 0.9204 (max: 43.02) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 33.13+/- 0.8861 (max: 43.02) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 19.64+/- 0.4019 (max: 22.48) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 30.35+/- 0.5063 (max: 34.45) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.7933+/- 0.03521 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9619+/- 0.00754 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.4637+/- 0.02647 (max: 0.6) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9544+/- 0.00555 (max: 0.99) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 18.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 83.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 18.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 71.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 17.52 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.86 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 17.52 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 27.44 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.25 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.89 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.25 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9 | +------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED1 against population in Overcooked-CoordRing6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [7.199999809265137, 7.199999809265137, 19.0, 18.600000381469727, 27.399999618530273, 22.399999618530273, 15.59999942779541, 17.0, 32.0, 32.0, 25.0, 24.0, 14.59999942779541, 14.0, 25.799999237060547, 27.19999885559082, 26.599998474121094, 27.19999885559082, 15.0, 14.399999618530273, 17.799999237060547, 17.399999618530273, 22.19999885559082, 24.0, 14.799999237060547, 14.199999809265137, 21.19999885559082, 22.0, 27.799999237060547, 28.399999618530273, 17.600000381469727, 19.600000381469727, 26.799999237060547, 31.599998474121094, 40.599998474121094, 34.20000076293945, 4.0, 5.799999713897705, 20.399999618530273, 19.399999618530273, 24.600000381469727, 28.19999885559082, 14.799999237060547, 13.799999237060547, 15.799999237060547, 15.199999809265137, 12.199999809265137, 16.600000381469727] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [7.199999809265137, 7.199999809265137, 15.59999942779541, 17.0, 14.59999942779541, 14.0, 15.0, 14.399999618530273, 14.799999237060547, 14.199999809265137, 17.600000381469727, 19.600000381469727, 4.0, 5.799999713897705, 14.799999237060547, 13.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [19.0, 18.600000381469727, 32.0, 32.0, 25.799999237060547, 27.19999885559082, 17.799999237060547, 17.399999618530273, 21.19999885559082, 22.0, 26.799999237060547, 31.599998474121094, 20.399999618530273, 19.399999618530273, 15.799999237060547, 15.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [27.399999618530273, 22.399999618530273, 25.0, 24.0, 26.599998474121094, 27.19999885559082, 22.19999885559082, 24.0, 27.799999237060547, 28.399999618530273, 40.599998474121094, 34.20000076293945, 24.600000381469727, 28.19999885559082, 12.199999809265137, 16.600000381469727] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 20.48+/- 1.116 (max: 40.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 25.71+/- 1.599 (max: 40.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 13.1+/- 1.128 (max: 19.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 22.64+/- 1.449 (max: 32.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 13.09+/- 0.2731 (max: 17.04) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 12.52+/- 0.4405 (max: 17.04) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.86+/- 0.5939 (max: 16.34) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.89+/- 0.2947 (max: 16.55) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.265+/- 0.02739 (max: 0.85) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.3687+/- 0.04959 (max: 0.85) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.1169+/- 0.01934 (max: 0.26) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.3094+/- 0.04365 (max: 0.62) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 4.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 12.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 4.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 15.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 8.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 8.0 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.27 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.1 | +----------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [4.400000095367432, 0.19999998807907104, 5.400000095367432, 0.19999998807907104, 5.599999904632568, 0.5999999642372131, 1.7999999523162842, 0.0, 2.5999999046325684, 0.19999998807907104, 3.0, 0.3999999761581421, 3.0, 0.19999998807907104, 4.599999904632568, 0.19999998807907104, 4.400000095367432, 0.3999999761581421, 3.1999998092651367, 0.5999999642372131, 3.5999999046325684, 0.3999999761581421, 4.400000095367432, 0.19999998807907104, 4.0, 0.3999999761581421, 4.0, 0.19999998807907104, 3.3999998569488525, 0.5999999642372131, 3.3999998569488525, 0.0, 2.5999999046325684, 0.3999999761581421, 2.5999999046325684, 0.0, 1.1999999284744263, 0.0, 2.5999999046325684, 0.3999999761581421, 2.200000047683716, 0.0, 3.1999998092651367, 0.0, 6.0, 0.0, 5.0, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [4.400000095367432, 0.19999998807907104, 1.7999999523162842, 0.0, 3.0, 0.19999998807907104, 3.1999998092651367, 0.5999999642372131, 4.0, 0.3999999761581421, 3.3999998569488525, 0.0, 1.1999999284744263, 0.0, 3.1999998092651367, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [5.400000095367432, 0.19999998807907104, 2.5999999046325684, 0.19999998807907104, 4.599999904632568, 0.19999998807907104, 3.5999999046325684, 0.3999999761581421, 4.0, 0.19999998807907104, 2.5999999046325684, 0.3999999761581421, 2.5999999046325684, 0.3999999761581421, 6.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [5.599999904632568, 0.5999999642372131, 3.0, 0.3999999761581421, 4.400000095367432, 0.3999999761581421, 4.400000095367432, 0.19999998807907104, 3.3999998569488525, 0.5999999642372131, 2.5999999046325684, 0.0, 2.200000047683716, 0.0, 5.0, 0.3999999761581421] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.921+/- 0.2739 (max: 6.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 2.075+/- 0.4983 (max: 5.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.6+/- 0.4119 (max: 4.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.087+/- 0.5263 (max: 6.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.706+/- 0.4548 (max: 9.539) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 5.028+/- 0.8035 (max: 9.539) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.171+/- 0.8164 (max: 8.754) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.918+/- 0.7769 (max: 9.165) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0008333+/- 0.0005012 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.001875+/- 0.00136 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_SoftMoE_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [5.599999904632568, 3.1999998092651367, 8.0, 5.199999809265137, 5.400000095367432, 1.7999999523162842, 2.200000047683716, 2.5999999046325684, 5.0, 3.0, 1.399999976158142, 0.0, 3.1999998092651367, 1.7999999523162842, 11.399999618530273, 7.199999809265137, 5.0, 3.3999998569488525, 4.599999904632568, 1.7999999523162842, 13.799999237060547, 9.399999618530273, 6.0, 1.5999999046325684, 3.0, 3.0, 11.399999618530273, 7.199999809265137, 4.199999809265137, 2.200000047683716, 4.400000095367432, 2.200000047683716, 9.399999618530273, 4.599999904632568, 3.1999998092651367, 2.200000047683716, 2.3999998569488525, 2.5999999046325684, 5.799999713897705, 4.199999809265137, 3.0, 0.5999999642372131, 5.400000095367432, 3.3999998569488525, 7.399999618530273, 3.1999998092651367, 4.199999809265137, 1.399999976158142] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [5.599999904632568, 3.1999998092651367, 2.200000047683716, 2.5999999046325684, 3.1999998092651367, 1.7999999523162842, 4.599999904632568, 1.7999999523162842, 3.0, 3.0, 4.400000095367432, 2.200000047683716, 2.3999998569488525, 2.5999999046325684, 5.400000095367432, 3.3999998569488525] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [8.0, 5.199999809265137, 5.0, 3.0, 11.399999618530273, 7.199999809265137, 13.799999237060547, 9.399999618530273, 11.399999618530273, 7.199999809265137, 9.399999618530273, 4.599999904632568, 5.799999713897705, 4.199999809265137, 7.399999618530273, 3.1999998092651367] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [5.400000095367432, 1.7999999523162842, 1.399999976158142, 0.0, 5.0, 3.3999998569488525, 6.0, 1.5999999046325684, 4.199999809265137, 2.200000047683716, 3.1999998092651367, 2.200000047683716, 3.0, 0.5999999642372131, 4.199999809265137, 1.399999976158142] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 4.442+/- 0.4257 (max: 13.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 2.85+/- 0.438 (max: 6.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.212+/- 0.2986 (max: 5.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 7.262+/- 0.7888 (max: 13.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 8.003+/- 0.3485 (max: 12.85) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 6.449+/- 0.6006 (max: 9.319) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.313+/- 0.2724 (max: 9.415) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.25+/- 0.4224 (max: 12.85) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.01146+/- 0.003234 (max: 0.09) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0025+/- 0.001118 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.03062+/- 0.007717 (max: 0.09) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 3.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.724 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 7.681 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.5999999642372131, 6.0, 0.0, 16.19999885559082, 0.5999999642372131, 6.599999904632568, 0.19999998807907104, 9.800000190734863, 0.0, 45.599998474121094, 0.3999999761581421, 6.599999904632568, 0.0, 41.599998474121094, 0.0, 61.79999923706055, 0.19999998807907104, 15.0, 0.0, 19.799999237060547, 0.0, 15.199999809265137, 0.19999998807907104, 9.0, 0.3999999761581421, 5.400000095367432, 0.0, 23.19999885559082, 0.0, 7.599999904632568, 0.5999999642372131, 7.599999904632568, 0.0, 28.399999618530273, 0.3999999761581421, 6.599999904632568, 0.19999998807907104, 5.400000095367432, 0.0, 22.0, 0.7999999523162842, 12.0, 0.19999998807907104, 7.399999618530273, 0.0, 24.399999618530273, 0.7999999523162842, 6.199999809265137] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.5999999642372131, 6.0, 0.19999998807907104, 9.800000190734863, 0.0, 41.599998474121094, 0.0, 19.799999237060547, 0.3999999761581421, 5.400000095367432, 0.5999999642372131, 7.599999904632568, 0.19999998807907104, 5.400000095367432, 0.19999998807907104, 7.399999618530273] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 16.19999885559082, 0.0, 45.599998474121094, 0.0, 61.79999923706055, 0.0, 15.199999809265137, 0.0, 23.19999885559082, 0.0, 28.399999618530273, 0.0, 22.0, 0.0, 24.399999618530273] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.5999999642372131, 6.599999904632568, 0.3999999761581421, 6.599999904632568, 0.19999998807907104, 15.0, 0.19999998807907104, 9.0, 0.0, 7.599999904632568, 0.3999999761581421, 6.599999904632568, 0.7999999523162842, 12.0, 0.7999999523162842, 6.199999809265137] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.646+/- 1.922 (max: 61.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 4.562+/- 1.2 (max: 15.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.575+/- 2.689 (max: 41.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 14.8+/- 4.704 (max: 61.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.702+/- 1.228 (max: 30.74) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.281+/- 1.246 (max: 13.56) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.979+/- 1.918 (max: 28.24) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.85+/- 2.908 (max: 30.74) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1025+/- 0.02856 (max: 0.85) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.02937+/- 0.009809 (max: 0.13) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.06812+/- 0.03898 (max: 0.59) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.21+/- 0.06963 (max: 0.85) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [17.0, 17.600000381469727, 64.19999694824219, 64.0, 75.0, 75.0, 21.399999618530273, 25.0, 61.599998474121094, 53.79999923706055, 81.79999542236328, 83.4000015258789, 19.600000381469727, 21.799999237060547, 73.4000015258789, 70.0, 76.19999694824219, 70.5999984741211, 23.0, 23.19999885559082, 62.0, 64.5999984741211, 56.19999694824219, 57.599998474121094, 12.199999809265137, 16.0, 60.599998474121094, 56.79999923706055, 85.4000015258789, 79.79999542236328, 24.600000381469727, 24.399999618530273, 63.79999923706055, 60.0, 80.4000015258789, 82.4000015258789, 24.600000381469727, 26.399999618530273, 74.19999694824219, 68.4000015258789, 84.4000015258789, 84.79999542236328, 26.0, 25.599998474121094, 62.39999771118164, 63.599998474121094, 64.79999542236328, 73.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [17.0, 17.600000381469727, 21.399999618530273, 25.0, 19.600000381469727, 21.799999237060547, 23.0, 23.19999885559082, 12.199999809265137, 16.0, 24.600000381469727, 24.399999618530273, 24.600000381469727, 26.399999618530273, 26.0, 25.599998474121094] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [64.19999694824219, 64.0, 61.599998474121094, 53.79999923706055, 73.4000015258789, 70.0, 62.0, 64.5999984741211, 60.599998474121094, 56.79999923706055, 63.79999923706055, 60.0, 74.19999694824219, 68.4000015258789, 62.39999771118164, 63.599998474121094] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [75.0, 75.0, 81.79999542236328, 83.4000015258789, 76.19999694824219, 70.5999984741211, 56.19999694824219, 57.599998474121094, 85.4000015258789, 79.79999542236328, 80.4000015258789, 82.4000015258789, 84.4000015258789, 84.79999542236328, 64.79999542236328, 73.4000015258789] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 53.81+/- 3.506 (max: 85.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 75.7+/- 2.32 (max: 85.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 21.77+/- 1.042 (max: 26.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 63.96+/- 1.36 (max: 74.2) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 29.84+/- 1.2 (max: 43.54) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 36.57+/- 0.8246 (max: 43.54) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 19.0+/- 0.4532 (max: 21.61) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 33.96+/- 0.8097 (max: 39.83) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6675+/- 0.03748 (max: 0.95) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.8631+/- 0.01688 (max: 0.95) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.3169+/- 0.02188 (max: 0.41) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.8225+/- 0.009725 (max: 0.87) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 12.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 56.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 12.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 53.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 14.94 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 29.26 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 14.94 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 28.53 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.14 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.72 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.14 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.74 | +-------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.799999713897705, 4.400000095367432, 24.0, 21.0, 27.19999885559082, 29.599998474121094, 16.799999237060547, 14.0, 34.599998474121094, 35.79999923706055, 28.399999618530273, 31.0, 14.399999618530273, 15.0, 28.0, 29.19999885559082, 30.19999885559082, 34.599998474121094, 18.799999237060547, 12.0, 20.799999237060547, 21.600000381469727, 25.599998474121094, 25.599998474121094, 18.19999885559082, 14.799999237060547, 26.19999885559082, 24.799999237060547, 29.599998474121094, 30.399999618530273, 19.0, 17.399999618530273, 28.19999885559082, 26.0, 38.599998474121094, 39.0, 6.0, 5.0, 19.799999237060547, 21.399999618530273, 32.79999923706055, 35.79999923706055, 17.19999885559082, 15.399999618530273, 19.799999237060547, 21.399999618530273, 20.799999237060547, 25.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.799999713897705, 4.400000095367432, 16.799999237060547, 14.0, 14.399999618530273, 15.0, 18.799999237060547, 12.0, 18.19999885559082, 14.799999237060547, 19.0, 17.399999618530273, 6.0, 5.0, 17.19999885559082, 15.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [24.0, 21.0, 34.599998474121094, 35.79999923706055, 28.0, 29.19999885559082, 20.799999237060547, 21.600000381469727, 26.19999885559082, 24.799999237060547, 28.19999885559082, 26.0, 19.799999237060547, 21.399999618530273, 19.799999237060547, 21.399999618530273] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [27.19999885559082, 29.599998474121094, 28.399999618530273, 31.0, 30.19999885559082, 34.599998474121094, 25.599998474121094, 25.599998474121094, 29.599998474121094, 30.399999618530273, 38.599998474121094, 39.0, 32.79999923706055, 35.79999923706055, 20.799999237060547, 25.399999618530273] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 22.92+/- 1.261 (max: 39.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 30.29+/- 1.246 (max: 39.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 13.32+/- 1.319 (max: 19.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 25.16+/- 1.248 (max: 35.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 13.54+/- 0.3026 (max: 18.22) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 13.71+/- 0.5262 (max: 17.97) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.84+/- 0.5976 (max: 16.2) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 14.06+/- 0.418 (max: 18.22) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.336+/- 0.03009 (max: 0.75) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.5194+/- 0.03608 (max: 0.75) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.1231+/- 0.01947 (max: 0.26) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.3656+/- 0.03884 (max: 0.71) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 4.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 20.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 4.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 19.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 8.542 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.09 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 8.542 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.66 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.26 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.19 | +----------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [2.799999952316284, 0.0, 1.5999999046325684, 0.19999998807907104, 0.19999998807907104, 0.5999999642372131, 1.399999976158142, 0.19999998807907104, 0.7999999523162842, 0.3999999761581421, 1.1999999284744263, 0.0, 2.3999998569488525, 0.3999999761581421, 1.7999999523162842, 0.3999999761581421, 2.200000047683716, 0.7999999523162842, 1.7999999523162842, 0.0, 1.5999999046325684, 0.0, 1.0, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 1.0, 0.19999998807907104, 1.399999976158142, 0.3999999761581421, 0.7999999523162842, 0.0, 0.5999999642372131, 0.5999999642372131, 0.3999999761581421, 0.19999998807907104, 0.5999999642372131, 0.0, 0.3999999761581421, 0.0, 0.5999999642372131, 0.0, 2.5999999046325684, 0.0, 2.5999999046325684, 0.0, 4.799999713897705, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [2.799999952316284, 0.0, 1.399999976158142, 0.19999998807907104, 2.3999998569488525, 0.3999999761581421, 1.7999999523162842, 0.0, 0.5999999642372131, 0.19999998807907104, 0.7999999523162842, 0.0, 0.5999999642372131, 0.0, 2.5999999046325684, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [1.5999999046325684, 0.19999998807907104, 0.7999999523162842, 0.3999999761581421, 1.7999999523162842, 0.3999999761581421, 1.5999999046325684, 0.0, 1.0, 0.19999998807907104, 0.5999999642372131, 0.5999999642372131, 0.3999999761581421, 0.0, 2.5999999046325684, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.19999998807907104, 0.5999999642372131, 1.1999999284744263, 0.0, 2.200000047683716, 0.7999999523162842, 1.0, 0.19999998807907104, 1.399999976158142, 0.3999999761581421, 0.3999999761581421, 0.19999998807907104, 0.5999999642372131, 0.0, 4.799999713897705, 0.19999998807907104] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.8375+/- 0.1423 (max: 4.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.8875+/- 0.2994 (max: 4.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.8625+/- 0.2521 (max: 2.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.7625+/- 0.1908 (max: 2.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.179+/- 0.3258 (max: 8.542) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.332+/- 0.5486 (max: 8.542) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.032+/- 0.6511 (max: 6.94) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.174+/- 0.5195 (max: 6.726) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating PLR_SoftMoE_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [8.0, 5.599999904632568, 10.59999942779541, 12.199999809265137, 7.199999809265137, 3.1999998092651367, 6.199999809265137, 3.0, 5.599999904632568, 4.799999713897705, 1.7999999523162842, 0.0, 5.0, 3.0, 15.399999618530273, 9.800000190734863, 7.599999904632568, 4.599999904632568, 5.199999809265137, 3.1999998092651367, 18.600000381469727, 16.0, 9.399999618530273, 3.5999999046325684, 5.0, 4.199999809265137, 12.799999237060547, 10.800000190734863, 6.0, 1.5999999046325684, 9.399999618530273, 3.799999952316284, 11.800000190734863, 9.399999618530273, 6.799999713897705, 2.200000047683716, 6.599999904632568, 3.3999998569488525, 8.399999618530273, 8.0, 7.599999904632568, 1.7999999523162842, 8.800000190734863, 5.400000095367432, 11.399999618530273, 10.0, 8.800000190734863, 2.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [8.0, 5.599999904632568, 6.199999809265137, 3.0, 5.0, 3.0, 5.199999809265137, 3.1999998092651367, 5.0, 4.199999809265137, 9.399999618530273, 3.799999952316284, 6.599999904632568, 3.3999998569488525, 8.800000190734863, 5.400000095367432] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [10.59999942779541, 12.199999809265137, 5.599999904632568, 4.799999713897705, 15.399999618530273, 9.800000190734863, 18.600000381469727, 16.0, 12.799999237060547, 10.800000190734863, 11.800000190734863, 9.399999618530273, 8.399999618530273, 8.0, 11.399999618530273, 10.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [7.199999809265137, 3.1999998092651367, 1.7999999523162842, 0.0, 7.599999904632568, 4.599999904632568, 9.399999618530273, 3.5999999046325684, 6.0, 1.5999999046325684, 6.799999713897705, 2.200000047683716, 7.599999904632568, 1.7999999523162842, 8.800000190734863, 2.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 6.992+/- 0.5864 (max: 18.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 4.637+/- 0.7459 (max: 9.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 5.362+/- 0.5047 (max: 9.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 10.97+/- 0.904 (max: 18.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 9.638+/- 0.41 (max: 16.06) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 7.581+/- 0.6792 (max: 10.75) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 9.033+/- 0.3242 (max: 11.43) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 12.3+/- 0.4854 (max: 16.06) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.03042+/- 0.006671 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.001875+/- 0.00136 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.01062+/- 0.003091 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.07875+/- 0.01307 (max: 0.17) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 4.8 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.141 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.998 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_SoftMoE_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.3999999761581421, 1.5999999046325684, 0.3999999761581421, 9.800000190734863, 1.1999999284744263, 6.599999904632568, 0.19999998807907104, 6.599999904632568, 0.0, 43.20000076293945, 0.0, 7.599999904632568, 0.3999999761581421, 33.79999923706055, 0.19999998807907104, 55.599998474121094, 0.7999999523162842, 10.800000190734863, 0.19999998807907104, 17.0, 0.3999999761581421, 9.399999618530273, 1.0, 3.1999998092651367, 0.7999999523162842, 2.0, 0.19999998807907104, 19.399999618530273, 0.3999999761581421, 4.599999904632568, 0.5999999642372131, 3.3999998569488525, 0.3999999761581421, 27.799999237060547, 0.0, 5.199999809265137, 0.0, 4.400000095367432, 0.19999998807907104, 18.19999885559082, 0.5999999642372131, 6.599999904632568, 0.19999998807907104, 2.799999952316284, 0.0, 21.799999237060547, 0.7999999523162842, 2.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.3999999761581421, 1.5999999046325684, 0.19999998807907104, 6.599999904632568, 0.3999999761581421, 33.79999923706055, 0.19999998807907104, 17.0, 0.7999999523162842, 2.0, 0.5999999642372131, 3.3999998569488525, 0.0, 4.400000095367432, 0.19999998807907104, 2.799999952316284] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.3999999761581421, 9.800000190734863, 0.0, 43.20000076293945, 0.19999998807907104, 55.599998474121094, 0.3999999761581421, 9.399999618530273, 0.19999998807907104, 19.399999618530273, 0.3999999761581421, 27.799999237060547, 0.19999998807907104, 18.19999885559082, 0.0, 21.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.1999999284744263, 6.599999904632568, 0.0, 7.599999904632568, 0.7999999523162842, 10.800000190734863, 1.0, 3.1999998092651367, 0.3999999761581421, 4.599999904632568, 0.0, 5.199999809265137, 0.5999999642372131, 6.599999904632568, 0.7999999523162842, 2.0] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 6.933+/- 1.716 (max: 55.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.212+/- 0.8243 (max: 10.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 4.65+/- 2.213 (max: 33.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 12.94+/- 4.292 (max: 55.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 8.045+/- 1.111 (max: 31.19) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 6.305+/- 0.9994 (max: 11.93) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 6.692+/- 1.656 (max: 25.56) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 11.14+/- 2.627 (max: 31.19) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.07646+/- 0.0253 (max: 0.77) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.01375+/- 0.005072 (max: 0.06) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.04437+/- 0.03089 (max: 0.46) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1712+/- 0.06408 (max: 0.77) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [8.399999618530273, 7.399999618530273, 68.4000015258789, 69.5999984741211, 101.79999542236328, 103.5999984741211, 14.399999618530273, 12.199999809265137, 81.19999694824219, 70.0, 109.79999542236328, 109.19999694824219, 16.19999885559082, 15.199999809265137, 99.19999694824219, 94.79999542236328, 108.79999542236328, 104.0, 19.19999885559082, 17.0, 90.79999542236328, 88.19999694824219, 93.4000015258789, 95.5999984741211, 9.0, 8.199999809265137, 72.5999984741211, 68.5999984741211, 96.19999694824219, 103.79999542236328, 16.399999618530273, 19.0, 94.79999542236328, 86.79999542236328, 109.0, 109.39999389648438, 14.199999809265137, 17.0, 83.0, 80.4000015258789, 108.5999984741211, 107.5999984741211, 21.600000381469727, 17.399999618530273, 91.19999694824219, 89.19999694824219, 111.19999694824219, 109.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [8.399999618530273, 7.399999618530273, 14.399999618530273, 12.199999809265137, 16.19999885559082, 15.199999809265137, 19.19999885559082, 17.0, 9.0, 8.199999809265137, 16.399999618530273, 19.0, 14.199999809265137, 17.0, 21.600000381469727, 17.399999618530273] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [68.4000015258789, 69.5999984741211, 81.19999694824219, 70.0, 99.19999694824219, 94.79999542236328, 90.79999542236328, 88.19999694824219, 72.5999984741211, 68.5999984741211, 94.79999542236328, 86.79999542236328, 83.0, 80.4000015258789, 91.19999694824219, 89.19999694824219] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [101.79999542236328, 103.5999984741211, 109.79999542236328, 109.19999694824219, 108.79999542236328, 104.0, 93.4000015258789, 95.5999984741211, 96.19999694824219, 103.79999542236328, 109.0, 109.39999389648438, 108.5999984741211, 107.5999984741211, 111.19999694824219, 109.19999694824219] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 67.56+/- 5.716 (max: 111.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 105.1+/- 1.416 (max: 111.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 14.55+/- 1.087 (max: 21.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 83.05+/- 2.611 (max: 99.2) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 22.38+/- 0.9508 (max: 34.81) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 26.23+/- 0.5773 (max: 31.25) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.77+/- 0.4501 (max: 16.9) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 27.13+/- 0.7226 (max: 34.81) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6887+/- 0.05664 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9825+/- 0.003476 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.145+/- 0.01992 (max: 0.27) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9387+/- 0.01076 (max: 0.99) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 93.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 68.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.62 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 23.4 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.62 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 23.82 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.95 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.82 | +------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [5.599999904632568, 5.599999904632568, 25.19999885559082, 21.399999618530273, 29.19999885559082, 28.399999618530273, 15.799999237060547, 19.19999885559082, 35.79999923706055, 40.0, 29.19999885559082, 30.399999618530273, 14.59999942779541, 17.19999885559082, 30.799999237060547, 31.19999885559082, 28.19999885559082, 26.799999237060547, 14.59999942779541, 11.800000190734863, 20.19999885559082, 21.799999237060547, 24.19999885559082, 23.600000381469727, 18.600000381469727, 18.799999237060547, 26.799999237060547, 29.399999618530273, 31.599998474121094, 32.0, 15.399999618530273, 14.0, 23.0, 27.599998474121094, 41.599998474121094, 44.20000076293945, 3.1999998092651367, 5.400000095367432, 18.19999885559082, 20.19999885559082, 22.0, 22.19999885559082, 16.0, 16.19999885559082, 17.799999237060547, 17.799999237060547, 20.0, 17.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [5.599999904632568, 5.599999904632568, 15.799999237060547, 19.19999885559082, 14.59999942779541, 17.19999885559082, 14.59999942779541, 11.800000190734863, 18.600000381469727, 18.799999237060547, 15.399999618530273, 14.0, 3.1999998092651367, 5.400000095367432, 16.0, 16.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [25.19999885559082, 21.399999618530273, 35.79999923706055, 40.0, 30.799999237060547, 31.19999885559082, 20.19999885559082, 21.799999237060547, 26.799999237060547, 29.399999618530273, 23.0, 27.599998474121094, 18.19999885559082, 20.19999885559082, 17.799999237060547, 17.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [29.19999885559082, 28.399999618530273, 29.19999885559082, 30.399999618530273, 28.19999885559082, 26.799999237060547, 24.19999885559082, 23.600000381469727, 31.599998474121094, 32.0, 41.599998474121094, 44.20000076293945, 22.0, 22.19999885559082, 20.0, 17.799999237060547] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 22.3+/- 1.31 (max: 44.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 28.21+/- 1.776 (max: 44.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 13.25+/- 1.329 (max: 19.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 25.45+/- 1.662 (max: 40.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 14.12+/- 0.3525 (max: 18.19) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 14.19+/- 0.4903 (max: 17.59) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.99+/- 0.6906 (max: 16.93) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 15.16+/- 0.541 (max: 18.19) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.3258+/- 0.03085 (max: 0.87) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.4612+/- 0.04632 (max: 0.87) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.1312+/- 0.02313 (max: 0.29) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.385+/- 0.04779 (max: 0.74) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 3.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 17.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 3.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 17.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 7.859 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 11.13 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 7.859 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.49 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.12 | +----------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [4.799999713897705, 0.19999998807907104, 8.199999809265137, 0.19999998807907104, 5.199999809265137, 0.19999998807907104, 5.599999904632568, 0.0, 4.799999713897705, 0.3999999761581421, 3.1999998092651367, 0.19999998807907104, 4.400000095367432, 0.19999998807907104, 4.0, 0.0, 3.5999999046325684, 0.0, 5.799999713897705, 0.19999998807907104, 3.799999952316284, 0.19999998807907104, 3.799999952316284, 0.7999999523162842, 4.400000095367432, 0.0, 6.0, 0.19999998807907104, 6.799999713897705, 0.3999999761581421, 3.5999999046325684, 0.3999999761581421, 3.3999998569488525, 0.19999998807907104, 2.0, 0.0, 3.0, 0.19999998807907104, 3.3999998569488525, 0.3999999761581421, 1.5999999046325684, 0.3999999761581421, 7.399999618530273, 0.19999998807907104, 8.0, 0.5999999642372131, 8.59999942779541, 0.7999999523162842] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [4.799999713897705, 0.19999998807907104, 5.599999904632568, 0.0, 4.400000095367432, 0.19999998807907104, 5.799999713897705, 0.19999998807907104, 4.400000095367432, 0.0, 3.5999999046325684, 0.3999999761581421, 3.0, 0.19999998807907104, 7.399999618530273, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [8.199999809265137, 0.19999998807907104, 4.799999713897705, 0.3999999761581421, 4.0, 0.0, 3.799999952316284, 0.19999998807907104, 6.0, 0.19999998807907104, 3.3999998569488525, 0.19999998807907104, 3.3999998569488525, 0.3999999761581421, 8.0, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [5.199999809265137, 0.19999998807907104, 3.1999998092651367, 0.19999998807907104, 3.5999999046325684, 0.0, 3.799999952316284, 0.7999999523162842, 6.799999713897705, 0.3999999761581421, 2.0, 0.0, 1.5999999046325684, 0.3999999761581421, 8.59999942779541, 0.7999999523162842] +--------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 2.537+/- 0.3839 (max: 8.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 2.35+/- 0.6607 (max: 8.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 2.525+/- 0.6514 (max: 7.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.737+/- 0.721 (max: 8.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 5.221+/- 0.5079 (max: 10.62) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 5.097+/- 0.8544 (max: 10.3) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 5.078+/- 0.93 (max: 9.831) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 5.487+/- 0.9074 (max: 10.62) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0025+/- 0.0008681 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0025+/- 0.001443 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.00125+/- 0.00125 (max: 0.02) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.00375+/- 0.001797 (max: 0.02) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +--------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [7.0, 6.599999904632568, 8.399999618530273, 7.199999809265137, 6.599999904632568, 2.799999952316284, 4.0, 3.799999952316284, 7.199999809265137, 4.599999904632568, 1.0, 0.19999998807907104, 2.5999999046325684, 3.1999998092651367, 10.800000190734863, 6.799999713897705, 11.59999942779541, 5.0, 6.199999809265137, 2.799999952316284, 14.59999942779541, 13.59999942779541, 7.0, 1.5999999046325684, 5.400000095367432, 3.5999999046325684, 8.199999809265137, 5.799999713897705, 4.400000095367432, 1.399999976158142, 3.3999998569488525, 1.5999999046325684, 9.399999618530273, 5.199999809265137, 5.0, 2.3999998569488525, 3.3999998569488525, 1.7999999523162842, 8.0, 6.399999618530273, 3.1999998092651367, 1.0, 4.199999809265137, 2.5999999046325684, 10.59999942779541, 10.0, 5.599999904632568, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [7.0, 6.599999904632568, 4.0, 3.799999952316284, 2.5999999046325684, 3.1999998092651367, 6.199999809265137, 2.799999952316284, 5.400000095367432, 3.5999999046325684, 3.3999998569488525, 1.5999999046325684, 3.3999998569488525, 1.7999999523162842, 4.199999809265137, 2.5999999046325684] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [8.399999618530273, 7.199999809265137, 7.199999809265137, 4.599999904632568, 10.800000190734863, 6.799999713897705, 14.59999942779541, 13.59999942779541, 8.199999809265137, 5.799999713897705, 9.399999618530273, 5.199999809265137, 8.0, 6.399999618530273, 10.59999942779541, 10.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [6.599999904632568, 2.799999952316284, 1.0, 0.19999998807907104, 11.59999942779541, 5.0, 7.0, 1.5999999046325684, 4.400000095367432, 1.399999976158142, 5.0, 2.3999998569488525, 3.1999998092651367, 1.0, 5.599999904632568, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 5.371+/- 0.4917 (max: 14.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 3.675+/- 0.7696 (max: 11.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.887+/- 0.4078 (max: 7.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.55+/- 0.7082 (max: 14.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 8.672+/- 0.4168 (max: 15.86) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 6.968+/- 0.7983 (max: 11.38) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.776+/- 0.3347 (max: 9.95) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 11.27+/- 0.4302 (max: 15.86) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.01937+/- 0.004683 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.00875+/- 0.0034 (max: 0.04) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0025+/- 0.001118 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.04687+/- 0.01079 (max: 0.17) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.6 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 4.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.426 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 8.773 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PLR_SoftMoE_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [1.0, 1.7999999523162842, 0.3999999761581421, 11.399999618530273, 1.399999976158142, 3.5999999046325684, 1.0, 5.799999713897705, 1.1999999284744263, 45.39999771118164, 1.5999999046325684, 11.800000190734863, 1.0, 31.599998474121094, 1.1999999284744263, 54.39999771118164, 1.1999999284744263, 9.399999618530273, 1.0, 18.399999618530273, 1.1999999284744263, 8.199999809265137, 0.7999999523162842, 2.3999998569488525, 0.7999999523162842, 2.799999952316284, 0.19999998807907104, 17.0, 0.5999999642372131, 4.599999904632568, 0.7999999523162842, 4.0, 0.5999999642372131, 28.799999237060547, 1.0, 8.399999618530273, 0.19999998807907104, 4.400000095367432, 1.0, 21.19999885559082, 1.0, 6.599999904632568, 0.7999999523162842, 4.199999809265137, 0.19999998807907104, 22.0, 1.5999999046325684, 3.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [1.0, 1.7999999523162842, 1.0, 5.799999713897705, 1.0, 31.599998474121094, 1.0, 18.399999618530273, 0.7999999523162842, 2.799999952316284, 0.7999999523162842, 4.0, 0.19999998807907104, 4.400000095367432, 0.7999999523162842, 4.199999809265137] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.3999999761581421, 11.399999618530273, 1.1999999284744263, 45.39999771118164, 1.1999999284744263, 54.39999771118164, 1.1999999284744263, 8.199999809265137, 0.19999998807907104, 17.0, 0.5999999642372131, 28.799999237060547, 1.0, 21.19999885559082, 0.19999998807907104, 22.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.399999976158142, 3.5999999046325684, 1.5999999046325684, 11.800000190734863, 1.1999999284744263, 9.399999618530273, 0.7999999523162842, 2.3999998569488525, 0.5999999642372131, 4.599999904632568, 1.0, 8.399999618530273, 1.0, 6.599999904632568, 1.5999999046325684, 3.0] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 7.354+/- 1.7 (max: 54.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.687+/- 0.8769 (max: 11.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 4.975+/- 2.086 (max: 31.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 13.4+/- 4.283 (max: 54.4) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 9.355+/- 1.051 (max: 32.26) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.904+/- 0.9881 (max: 16.52) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.804+/- 1.55 (max: 25.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.36+/- 2.471 (max: 32.26) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.08646+/- 0.02581 (max: 0.75) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.02562+/- 0.009397 (max: 0.12) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.05125+/- 0.03365 (max: 0.5) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1825+/- 0.06401 (max: 0.75) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 3.412 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating PLR_SoftMoE_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [13.59999942779541, 14.199999809265137, 66.0, 66.0, 87.79999542236328, 84.79999542236328, 19.399999618530273, 19.399999618530273, 66.0, 65.0, 85.19999694824219, 87.19999694824219, 22.600000381469727, 21.0, 81.19999694824219, 87.79999542236328, 93.0, 94.4000015258789, 19.799999237060547, 20.600000381469727, 78.0, 79.5999984741211, 69.5999984741211, 69.5999984741211, 15.0, 13.799999237060547, 67.4000015258789, 63.19999694824219, 77.5999984741211, 79.4000015258789, 19.19999885559082, 22.600000381469727, 83.19999694824219, 74.79999542236328, 96.19999694824219, 93.79999542236328, 20.600000381469727, 16.600000381469727, 75.5999984741211, 73.19999694824219, 90.0, 94.19999694824219, 22.399999618530273, 21.399999618530273, 82.79999542236328, 79.4000015258789, 83.5999984741211, 87.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [13.59999942779541, 14.199999809265137, 19.399999618530273, 19.399999618530273, 22.600000381469727, 21.0, 19.799999237060547, 20.600000381469727, 15.0, 13.799999237060547, 19.19999885559082, 22.600000381469727, 20.600000381469727, 16.600000381469727, 22.399999618530273, 21.399999618530273] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [66.0, 66.0, 66.0, 65.0, 81.19999694824219, 87.79999542236328, 78.0, 79.5999984741211, 67.4000015258789, 63.19999694824219, 83.19999694824219, 74.79999542236328, 75.5999984741211, 73.19999694824219, 82.79999542236328, 79.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [87.79999542236328, 84.79999542236328, 85.19999694824219, 87.19999694824219, 93.0, 94.4000015258789, 69.5999984741211, 69.5999984741211, 77.5999984741211, 79.4000015258789, 96.19999694824219, 93.79999542236328, 90.0, 94.19999694824219, 83.5999984741211, 87.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 59.68+/- 4.371 (max: 96.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 85.84+/- 2.074 (max: 96.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 18.89+/- 0.8015 (max: 22.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 74.32+/- 1.954 (max: 87.8) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 26.11+/- 1.206 (max: 39.14) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 33.15+/- 0.6656 (max: 39.14) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 14.98+/- 0.3526 (max: 17.29) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 30.19+/- 0.6522 (max: 34.48) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6881+/- 0.04774 (max: 0.99) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9256+/- 0.01411 (max: 0.99) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.2337+/- 0.02162 (max: 0.36) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.905+/- 0.009874 (max: 0.96) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 13.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 69.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 13.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 63.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 12.55 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 29.17 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 12.55 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 24.74 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.09 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.79 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.09 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.84 | +------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED1 against population in Overcooked-CoordRing6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [14.799999237060547, 16.19999885559082, 44.79999923706055, 40.79999923706055, 20.399999618530273, 22.600000381469727, 25.799999237060547, 26.799999237060547, 45.20000076293945, 44.0, 22.0, 26.599998474121094, 23.799999237060547, 24.799999237060547, 41.20000076293945, 43.20000076293945, 21.799999237060547, 23.19999885559082, 37.599998474121094, 36.39999771118164, 40.0, 44.20000076293945, 25.799999237060547, 26.399999618530273, 22.0, 26.0, 28.19999885559082, 29.599998474121094, 19.799999237060547, 22.799999237060547, 26.19999885559082, 23.799999237060547, 48.79999923706055, 45.599998474121094, 51.19999694824219, 60.0, 15.799999237060547, 16.600000381469727, 37.599998474121094, 37.20000076293945, 39.599998474121094, 39.79999923706055, 16.399999618530273, 19.799999237060547, 30.19999885559082, 32.79999923706055, 27.799999237060547, 27.599998474121094] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [14.799999237060547, 16.19999885559082, 25.799999237060547, 26.799999237060547, 23.799999237060547, 24.799999237060547, 37.599998474121094, 36.39999771118164, 22.0, 26.0, 26.19999885559082, 23.799999237060547, 15.799999237060547, 16.600000381469727, 16.399999618530273, 19.799999237060547] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [44.79999923706055, 40.79999923706055, 45.20000076293945, 44.0, 41.20000076293945, 43.20000076293945, 40.0, 44.20000076293945, 28.19999885559082, 29.599998474121094, 48.79999923706055, 45.599998474121094, 37.599998474121094, 37.20000076293945, 30.19999885559082, 32.79999923706055] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [20.399999618530273, 22.600000381469727, 22.0, 26.599998474121094, 21.799999237060547, 23.19999885559082, 25.799999237060547, 26.399999618530273, 19.799999237060547, 22.799999237060547, 51.19999694824219, 60.0, 39.599998474121094, 39.79999923706055, 27.799999237060547, 27.599998474121094] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 30.91+/- 1.566 (max: 60.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 29.84+/- 2.937 (max: 60.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 23.3+/- 1.706 (max: 37.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 39.59+/- 1.596 (max: 48.8) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 16.96+/- 0.3158 (max: 23.38) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 17.3+/- 0.7194 (max: 23.38) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 16.28+/- 0.4104 (max: 18.63) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 17.32+/- 0.452 (max: 20.65) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.504+/- 0.03459 (max: 0.92) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.4494+/- 0.05703 (max: 0.92) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.3425+/- 0.04084 (max: 0.7) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.72+/- 0.03575 (max: 0.86) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 14.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 19.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 14.8 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 28.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.84 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 14.06 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.84 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.86 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.11 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.11 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.44 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [6.399999618530273, 0.0, 7.599999904632568, 0.0, 1.7999999523162842, 0.0, 5.199999809265137, 0.0, 5.599999904632568, 0.0, 2.0, 0.0, 5.400000095367432, 0.0, 3.799999952316284, 0.0, 4.0, 0.0, 3.0, 0.0, 3.3999998569488525, 0.0, 1.399999976158142, 0.0, 4.199999809265137, 0.0, 5.199999809265137, 0.0, 5.599999904632568, 0.0, 2.799999952316284, 0.0, 2.200000047683716, 0.0, 0.19999998807907104, 0.0, 3.0, 0.0, 1.7999999523162842, 0.0, 2.5999999046325684, 0.0, 6.599999904632568, 0.0, 6.399999618530273, 0.0, 7.799999713897705, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [6.399999618530273, 0.0, 5.199999809265137, 0.0, 5.400000095367432, 0.0, 3.0, 0.0, 4.199999809265137, 0.0, 2.799999952316284, 0.0, 3.0, 0.0, 6.599999904632568, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [7.599999904632568, 0.0, 5.599999904632568, 0.0, 3.799999952316284, 0.0, 3.3999998569488525, 0.0, 5.199999809265137, 0.0, 2.200000047683716, 0.0, 1.7999999523162842, 0.0, 6.399999618530273, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [1.7999999523162842, 0.0, 2.0, 0.0, 4.0, 0.0, 1.399999976158142, 0.0, 5.599999904632568, 0.0, 0.19999998807907104, 0.0, 2.5999999046325684, 0.0, 7.799999713897705, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 2.042+/- 0.3644 (max: 7.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.587+/- 0.5912 (max: 7.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 2.287+/- 0.647 (max: 6.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 2.25+/- 0.6776 (max: 7.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.866+/- 0.5951 (max: 10.16) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 3.398+/- 0.9812 (max: 10.16) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.139+/- 1.083 (max: 9.404) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.061+/- 1.082 (max: 10.11) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.001458+/- 0.0005148 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0025+/- 0.001118 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [14.799999237060547, 16.399999618530273, 18.399999618530273, 16.799999237060547, 7.0, 10.59999942779541, 10.800000190734863, 12.399999618530273, 13.399999618530273, 15.59999942779541, 3.799999952316284, 4.400000095367432, 5.599999904632568, 3.799999952316284, 7.0, 8.0, 11.199999809265137, 7.199999809265137, 11.800000190734863, 15.59999942779541, 19.799999237060547, 21.600000381469727, 11.800000190734863, 12.799999237060547, 10.800000190734863, 12.799999237060547, 20.399999618530273, 18.799999237060547, 8.199999809265137, 8.800000190734863, 13.0, 13.799999237060547, 13.799999237060547, 13.799999237060547, 8.0, 11.0, 10.59999942779541, 15.199999809265137, 15.0, 13.399999618530273, 11.0, 9.800000190734863, 12.0, 16.19999885559082, 21.600000381469727, 15.59999942779541, 11.800000190734863, 9.59999942779541] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [14.799999237060547, 16.399999618530273, 10.800000190734863, 12.399999618530273, 5.599999904632568, 3.799999952316284, 11.800000190734863, 15.59999942779541, 10.800000190734863, 12.799999237060547, 13.0, 13.799999237060547, 10.59999942779541, 15.199999809265137, 12.0, 16.19999885559082] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [18.399999618530273, 16.799999237060547, 13.399999618530273, 15.59999942779541, 7.0, 8.0, 19.799999237060547, 21.600000381469727, 20.399999618530273, 18.799999237060547, 13.799999237060547, 13.799999237060547, 15.0, 13.399999618530273, 21.600000381469727, 15.59999942779541] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [7.0, 10.59999942779541, 3.799999952316284, 4.400000095367432, 11.199999809265137, 7.199999809265137, 11.800000190734863, 12.799999237060547, 8.199999809265137, 8.800000190734863, 8.0, 11.0, 11.0, 9.800000190734863, 11.800000190734863, 9.59999942779541] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 12.41+/- 0.6386 (max: 21.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 9.187+/- 0.6554 (max: 12.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 12.22+/- 0.8788 (max: 16.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 15.81+/- 1.08 (max: 21.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 12.17+/- 0.3143 (max: 18.91) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 10.86+/- 0.354 (max: 13.59) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 11.73+/- 0.4026 (max: 13.97) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.91+/- 0.5584 (max: 18.91) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.08708+/- 0.01136 (max: 0.33) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0325+/- 0.007444 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.06812+/- 0.01212 (max: 0.17) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1606+/- 0.02067 (max: 0.33) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 3.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 3.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 7.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 7.846 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 7.846 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.846 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 9.539 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [1.7999999523162842, 37.79999923706055, 0.5999999642372131, 49.0, 1.5999999046325684, 80.0, 1.5999999046325684, 38.0, 0.5999999642372131, 74.79999542236328, 1.1999999284744263, 61.0, 1.7999999523162842, 70.19999694824219, 1.7999999523162842, 84.79999542236328, 1.399999976158142, 38.599998474121094, 1.0, 58.39999771118164, 3.5999999046325684, 63.79999923706055, 3.3999998569488525, 29.399999618530273, 0.5999999642372131, 30.599998474121094, 0.3999999761581421, 57.39999771118164, 0.7999999523162842, 51.39999771118164, 1.0, 36.20000076293945, 0.7999999523162842, 71.19999694824219, 1.1999999284744263, 66.79999542236328, 0.5999999642372131, 32.0, 0.5999999642372131, 73.0, 1.0, 29.0, 1.0, 36.20000076293945, 0.19999998807907104, 68.5999984741211, 3.5999999046325684, 55.39999771118164] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [1.7999999523162842, 37.79999923706055, 1.5999999046325684, 38.0, 1.7999999523162842, 70.19999694824219, 1.0, 58.39999771118164, 0.5999999642372131, 30.599998474121094, 1.0, 36.20000076293945, 0.5999999642372131, 32.0, 1.0, 36.20000076293945] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.5999999642372131, 49.0, 0.5999999642372131, 74.79999542236328, 1.7999999523162842, 84.79999542236328, 3.5999999046325684, 63.79999923706055, 0.3999999761581421, 57.39999771118164, 0.7999999523162842, 71.19999694824219, 0.5999999642372131, 73.0, 0.19999998807907104, 68.5999984741211] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.5999999046325684, 80.0, 1.1999999284744263, 61.0, 1.399999976158142, 38.599998474121094, 3.3999998569488525, 29.399999618530273, 0.7999999523162842, 51.39999771118164, 1.1999999284744263, 66.79999542236328, 1.0, 29.0, 3.5999999046325684, 55.39999771118164] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 27.62+/- 4.231 (max: 84.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 26.61+/- 7.127 (max: 80.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 21.8+/- 5.846 (max: 70.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 34.45+/- 8.824 (max: 84.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 17.5+/- 2.017 (max: 42.51) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 20.13+/- 4.016 (max: 42.51) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 14.07+/- 2.526 (max: 30.53) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 18.3+/- 3.803 (max: 41.32) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.3604+/- 0.05515 (max: 0.95) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.3331+/- 0.09095 (max: 0.88) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.3212+/- 0.0873 (max: 0.91) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.4269+/- 0.1107 (max: 0.95) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.3+/- 0.04376 (max: 0.6) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.3+/- 0.07746 (max: 0.6) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.3+/- 0.07746 (max: 0.6) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.3+/- 0.07746 (max: 0.6) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.706+/- 0.2488 (max: 3.412) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.706+/- 0.4405 (max: 3.412) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.706+/- 0.4405 (max: 3.412) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.706+/- 0.4405 (max: 3.412) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.8 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 3.919 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 3.412 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.99 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [31.19999885559082, 28.799999237060547, 90.19999694824219, 94.79999542236328, 119.79999542236328, 125.39999389648438, 37.79999923706055, 38.39999771118164, 87.5999984741211, 94.0, 110.19999694824219, 111.5999984741211, 42.39999771118164, 37.39999771118164, 103.79999542236328, 106.39999389648438, 109.79999542236328, 112.19999694824219, 41.39999771118164, 41.20000076293945, 91.4000015258789, 96.0, 105.19999694824219, 103.79999542236328, 25.599998474121094, 26.0, 89.5999984741211, 88.4000015258789, 109.79999542236328, 113.19999694824219, 41.0, 39.599998474121094, 95.79999542236328, 103.39999389648438, 115.79999542236328, 106.19999694824219, 40.79999923706055, 40.39999771118164, 99.4000015258789, 97.5999984741211, 110.5999984741211, 117.19999694824219, 43.39999771118164, 43.599998474121094, 98.79999542236328, 95.0, 108.0, 111.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [31.19999885559082, 28.799999237060547, 37.79999923706055, 38.39999771118164, 42.39999771118164, 37.39999771118164, 41.39999771118164, 41.20000076293945, 25.599998474121094, 26.0, 41.0, 39.599998474121094, 40.79999923706055, 40.39999771118164, 43.39999771118164, 43.599998474121094] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [90.19999694824219, 94.79999542236328, 87.5999984741211, 94.0, 103.79999542236328, 106.39999389648438, 91.4000015258789, 96.0, 89.5999984741211, 88.4000015258789, 95.79999542236328, 103.39999389648438, 99.4000015258789, 97.5999984741211, 98.79999542236328, 95.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [119.79999542236328, 125.39999389648438, 110.19999694824219, 111.5999984741211, 109.79999542236328, 112.19999694824219, 105.19999694824219, 103.79999542236328, 109.79999542236328, 113.19999694824219, 115.79999542236328, 106.19999694824219, 110.5999984741211, 117.19999694824219, 108.0, 111.79999542236328] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 81.7+/- 4.736 (max: 125.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 111.9+/- 1.387 (max: 125.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 37.44+/- 1.515 (max: 43.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 95.76+/- 1.407 (max: 106.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 27.63+/- 0.8489 (max: 40.17) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 33.46+/- 1.016 (max: 40.17) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 21.09+/- 0.447 (max: 23.66) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 28.35+/- 0.6247 (max: 34.93) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.8594+/- 0.02583 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.98+/- 0.004564 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.6244+/- 0.02644 (max: 0.74) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9737+/- 0.005313 (max: 1.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 25.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 103.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 25.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 87.6 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 18.44 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 27.13 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 18.44 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 25.24 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.41 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.94 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.41 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.92 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [9.399999618530273, 11.800000190734863, 28.19999885559082, 27.0, 23.0, 28.19999885559082, 19.600000381469727, 23.0, 39.0, 44.599998474121094, 22.799999237060547, 34.599998474121094, 20.19999885559082, 18.0, 32.0, 37.79999923706055, 26.0, 33.599998474121094, 21.19999885559082, 22.799999237060547, 23.0, 26.799999237060547, 13.59999942779541, 28.799999237060547, 21.0, 22.799999237060547, 34.79999923706055, 37.0, 27.19999885559082, 35.39999771118164, 16.600000381469727, 20.600000381469727, 26.19999885559082, 31.19999885559082, 30.19999885559082, 30.19999885559082, 6.599999904632568, 9.0, 22.399999618530273, 25.399999618530273, 21.19999885559082, 24.19999885559082, 21.799999237060547, 20.399999618530273, 21.600000381469727, 22.600000381469727, 17.19999885559082, 16.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [9.399999618530273, 11.800000190734863, 19.600000381469727, 23.0, 20.19999885559082, 18.0, 21.19999885559082, 22.799999237060547, 21.0, 22.799999237060547, 16.600000381469727, 20.600000381469727, 6.599999904632568, 9.0, 21.799999237060547, 20.399999618530273] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [28.19999885559082, 27.0, 39.0, 44.599998474121094, 32.0, 37.79999923706055, 23.0, 26.799999237060547, 34.79999923706055, 37.0, 26.19999885559082, 31.19999885559082, 22.399999618530273, 25.399999618530273, 21.600000381469727, 22.600000381469727] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [23.0, 28.19999885559082, 22.799999237060547, 34.599998474121094, 26.0, 33.599998474121094, 13.59999942779541, 28.799999237060547, 27.19999885559082, 35.39999771118164, 30.19999885559082, 30.19999885559082, 21.19999885559082, 24.19999885559082, 17.19999885559082, 16.19999885559082] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 24.52+/- 1.16 (max: 44.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 25.77+/- 1.635 (max: 35.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 17.8+/- 1.368 (max: 23.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 29.97+/- 1.735 (max: 44.6) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 17.12+/- 0.4028 (max: 24.08) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 17.2+/- 0.6717 (max: 24.08) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 15.23+/- 0.5825 (max: 17.67) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 18.94+/- 0.5247 (max: 21.88) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.3633+/- 0.02498 (max: 0.75) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.3931+/- 0.03747 (max: 0.65) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.2212+/- 0.02772 (max: 0.35) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.4756+/- 0.03807 (max: 0.75) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 6.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 13.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 6.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 21.6 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 10.22 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 13.86 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 10.22 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 15.91 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.16 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.27 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [4.799999713897705, 0.7999999523162842, 5.799999713897705, 0.5999999642372131, 2.3999998569488525, 0.19999998807907104, 2.0, 0.0, 1.5999999046325684, 0.19999998807907104, 1.5999999046325684, 0.0, 3.5999999046325684, 1.5999999046325684, 3.3999998569488525, 0.3999999761581421, 3.1999998092651367, 1.1999999284744263, 3.3999998569488525, 0.5999999642372131, 2.200000047683716, 0.0, 2.0, 0.0, 4.599999904632568, 1.0, 5.599999904632568, 1.5999999046325684, 4.799999713897705, 0.3999999761581421, 2.3999998569488525, 0.3999999761581421, 2.3999998569488525, 0.3999999761581421, 0.7999999523162842, 0.19999998807907104, 1.7999999523162842, 0.19999998807907104, 2.0, 0.3999999761581421, 1.0, 0.3999999761581421, 3.0, 0.19999998807907104, 3.799999952316284, 0.19999998807907104, 7.0, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [4.799999713897705, 0.7999999523162842, 2.0, 0.0, 3.5999999046325684, 1.5999999046325684, 3.3999998569488525, 0.5999999642372131, 4.599999904632568, 1.0, 2.3999998569488525, 0.3999999761581421, 1.7999999523162842, 0.19999998807907104, 3.0, 0.19999998807907104] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [5.799999713897705, 0.5999999642372131, 1.5999999046325684, 0.19999998807907104, 3.3999998569488525, 0.3999999761581421, 2.200000047683716, 0.0, 5.599999904632568, 1.5999999046325684, 2.3999998569488525, 0.3999999761581421, 2.0, 0.3999999761581421, 3.799999952316284, 0.19999998807907104] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [2.3999998569488525, 0.19999998807907104, 1.5999999046325684, 0.0, 3.1999998092651367, 1.1999999284744263, 2.0, 0.0, 4.799999713897705, 0.3999999761581421, 0.7999999523162842, 0.19999998807907104, 1.0, 0.3999999761581421, 7.0, 0.3999999761581421] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +---------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.804+/- 0.2571 (max: 7.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.6+/- 0.4885 (max: 7.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.9+/- 0.3967 (max: 4.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.912+/- 0.4704 (max: 5.8) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.829+/- 0.3967 (max: 10.72) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 4.404+/- 0.7514 (max: 10.72) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 5.159+/- 0.6541 (max: 8.879) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 4.926+/- 0.6835 (max: 9.075) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.00125+/- 0.0007062 (max: 0.03) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0025+/- 0.001936 (max: 0.03) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +---------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [15.59999942779541, 9.399999618530273, 22.19999885559082, 14.59999942779541, 11.59999942779541, 11.59999942779541, 7.799999713897705, 9.800000190734863, 13.799999237060547, 12.0, 4.799999713897705, 1.7999999523162842, 8.0, 5.0, 13.59999942779541, 10.0, 10.800000190734863, 11.0, 11.399999618530273, 8.59999942779541, 24.399999618530273, 21.799999237060547, 12.199999809265137, 6.799999713897705, 9.399999618530273, 7.799999713897705, 16.0, 13.799999237060547, 8.199999809265137, 5.0, 14.799999237060547, 8.199999809265137, 20.0, 11.199999809265137, 8.800000190734863, 7.399999618530273, 10.399999618530273, 9.0, 20.399999618530273, 15.199999809265137, 12.399999618530273, 8.59999942779541, 13.199999809265137, 8.800000190734863, 20.0, 17.799999237060547, 11.0, 5.799999713897705] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [15.59999942779541, 9.399999618530273, 7.799999713897705, 9.800000190734863, 8.0, 5.0, 11.399999618530273, 8.59999942779541, 9.399999618530273, 7.799999713897705, 14.799999237060547, 8.199999809265137, 10.399999618530273, 9.0, 13.199999809265137, 8.800000190734863] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [22.19999885559082, 14.59999942779541, 13.799999237060547, 12.0, 13.59999942779541, 10.0, 24.399999618530273, 21.799999237060547, 16.0, 13.799999237060547, 20.0, 11.199999809265137, 20.399999618530273, 15.199999809265137, 20.0, 17.799999237060547] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [11.59999942779541, 11.59999942779541, 4.799999713897705, 1.7999999523162842, 10.800000190734863, 11.0, 12.199999809265137, 6.799999713897705, 8.199999809265137, 5.0, 8.800000190734863, 7.399999618530273, 12.399999618530273, 8.59999942779541, 11.0, 5.799999713897705] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 11.7+/- 0.7137 (max: 24.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 8.612+/- 0.7856 (max: 12.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 9.825+/- 0.6872 (max: 15.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 16.67+/- 1.086 (max: 24.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 12.1+/- 0.3574 (max: 18.74) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 10.57+/- 0.4904 (max: 14.54) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 11.2+/- 0.2517 (max: 12.83) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 14.52+/- 0.5448 (max: 18.74) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.08729+/- 0.01393 (max: 0.34) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.03062+/- 0.007498 (max: 0.1) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.04062+/- 0.007442 (max: 0.12) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1906+/- 0.02534 (max: 0.34) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 5.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 10.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.724 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.724 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 8.66 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 11.27 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.04 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_SoftMoE_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [18.19999885559082, 8.199999809265137, 6.0, 17.799999237060547, 36.79999923706055, 3.799999952316284, 19.600000381469727, 12.59999942779541, 24.799999237060547, 46.39999771118164, 26.19999885559082, 2.5999999046325684, 26.799999237060547, 37.20000076293945, 30.19999885559082, 64.79999542236328, 37.20000076293945, 15.799999237060547, 19.19999885559082, 22.0, 23.600000381469727, 11.199999809265137, 39.20000076293945, 6.599999904632568, 22.0, 8.59999942779541, 19.799999237060547, 20.600000381469727, 15.59999942779541, 8.399999618530273, 21.0, 6.799999713897705, 18.0, 31.799999237060547, 37.0, 3.3999998569488525, 20.600000381469727, 8.199999809265137, 25.19999885559082, 16.399999618530273, 46.39999771118164, 7.399999618530273, 17.0, 10.399999618530273, 10.399999618530273, 25.0, 35.39999771118164, 6.599999904632568] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [18.19999885559082, 8.199999809265137, 19.600000381469727, 12.59999942779541, 26.799999237060547, 37.20000076293945, 19.19999885559082, 22.0, 22.0, 8.59999942779541, 21.0, 6.799999713897705, 20.600000381469727, 8.199999809265137, 17.0, 10.399999618530273] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [6.0, 17.799999237060547, 24.799999237060547, 46.39999771118164, 30.19999885559082, 64.79999542236328, 23.600000381469727, 11.199999809265137, 19.799999237060547, 20.600000381469727, 18.0, 31.799999237060547, 25.19999885559082, 16.399999618530273, 10.399999618530273, 25.0] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [36.79999923706055, 3.799999952316284, 26.19999885559082, 2.5999999046325684, 37.20000076293945, 15.799999237060547, 39.20000076293945, 6.599999904632568, 15.59999942779541, 8.399999618530273, 37.0, 3.3999998569488525, 46.39999771118164, 7.399999618530273, 35.39999771118164, 6.599999904632568] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.19999998807907104] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 20.81+/- 1.909 (max: 64.8) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 20.52+/- 3.944 (max: 46.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 17.4+/- 2.027 (max: 37.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 24.5+/- 3.593 (max: 64.8) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 16.82+/- 0.7477 (max: 31.89) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 15.61+/- 1.344 (max: 23.13) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 15.42+/- 1.057 (max: 28.99) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 19.44+/- 1.284 (max: 31.89) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.2815+/- 0.033 (max: 0.85) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.3006+/- 0.07319 (max: 0.75) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.2075+/- 0.03668 (max: 0.54) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.3362+/- 0.05425 (max: 0.85) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.6+/- 0.05835 (max: 1.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.6+/- 0.1033 (max: 1.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.6+/- 0.1033 (max: 1.0) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.6+/- 0.1033 (max: 1.0) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 3.174+/- 0.1728 (max: 4.359) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 3.174+/- 0.3058 (max: 4.359) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 3.174+/- 0.3058 (max: 4.359) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 3.174+/- 0.3058 (max: 4.359) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 2.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 2.6 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.8 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 6.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 7.297 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.297 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 9.887 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 10.77 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.01 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.04 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.2 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.2 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.99 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 1.99 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [21.600000381469727, 15.799999237060547, 85.4000015258789, 87.0, 112.19999694824219, 116.39999389648438, 29.599998474121094, 26.399999618530273, 80.19999694824219, 80.5999984741211, 106.79999542236328, 105.19999694824219, 29.0, 27.399999618530273, 104.39999389648438, 98.5999984741211, 111.19999694824219, 112.5999984741211, 25.19999885559082, 29.599998474121094, 94.5999984741211, 93.5999984741211, 97.19999694824219, 101.0, 18.19999885559082, 22.0, 84.19999694824219, 89.0, 112.0, 109.5999984741211, 31.0, 30.599998474121094, 96.0, 96.19999694824219, 107.19999694824219, 112.39999389648438, 30.0, 29.599998474121094, 92.5999984741211, 87.79999542236328, 109.0, 109.39999389648438, 33.39999771118164, 34.20000076293945, 99.79999542236328, 96.79999542236328, 106.19999694824219, 104.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [21.600000381469727, 15.799999237060547, 29.599998474121094, 26.399999618530273, 29.0, 27.399999618530273, 25.19999885559082, 29.599998474121094, 18.19999885559082, 22.0, 31.0, 30.599998474121094, 30.0, 29.599998474121094, 33.39999771118164, 34.20000076293945] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [85.4000015258789, 87.0, 80.19999694824219, 80.5999984741211, 104.39999389648438, 98.5999984741211, 94.5999984741211, 93.5999984741211, 84.19999694824219, 89.0, 96.0, 96.19999694824219, 92.5999984741211, 87.79999542236328, 99.79999542236328, 96.79999542236328] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [112.19999694824219, 116.39999389648438, 106.79999542236328, 105.19999694824219, 111.19999694824219, 112.5999984741211, 97.19999694824219, 101.0, 112.0, 109.5999984741211, 107.19999694824219, 112.39999389648438, 109.0, 109.39999389648438, 106.19999694824219, 104.79999542236328] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.3999999761581421] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 75.7+/- 5.175 (max: 116.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 108.3+/- 1.204 (max: 116.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 27.1+/- 1.316 (max: 34.2) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 91.67+/- 1.765 (max: 104.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 25.4+/- 0.8028 (max: 39.24) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 28.66+/- 0.9033 (max: 36.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 19.05+/- 0.4251 (max: 22.21) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 28.49+/- 1.016 (max: 39.24) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.7927+/- 0.03853 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9812+/- 0.005692 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.4331+/- 0.02993 (max: 0.57) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9637+/- 0.009349 (max: 1.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.2+/- 0.02917 (max: 0.4) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.2+/- 0.05164 (max: 0.4) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.2+/- 0.05164 (max: 0.4) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.2+/- 0.05164 (max: 0.4) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 1.4+/- 0.2042 (max: 2.8) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 1.4+/- 0.3615 (max: 2.8) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 1.4+/- 0.3615 (max: 2.8) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 1.4+/- 0.3615 (max: 2.8) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 15.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 97.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.8 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 80.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 15.52 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 22.51 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 15.52 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 22.95 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.2 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.92 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.2 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.85 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [12.399999618530273, 11.0, 41.20000076293945, 36.0, 35.79999923706055, 40.39999771118164, 25.799999237060547, 26.399999618530273, 48.599998474121094, 53.39999771118164, 43.20000076293945, 48.0, 25.799999237060547, 24.600000381469727, 48.20000076293945, 49.79999923706055, 50.39999771118164, 54.79999923706055, 38.0, 35.20000076293945, 38.599998474121094, 37.0, 34.79999923706055, 38.0, 24.19999885559082, 28.799999237060547, 37.79999923706055, 40.0, 46.599998474121094, 49.599998474121094, 23.0, 24.799999237060547, 44.0, 43.79999923706055, 33.79999923706055, 43.39999771118164, 9.0, 7.599999904632568, 35.20000076293945, 27.399999618530273, 31.0, 34.599998474121094, 20.0, 24.600000381469727, 24.19999885559082, 22.799999237060547, 23.799999237060547, 22.600000381469727] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [12.399999618530273, 11.0, 25.799999237060547, 26.399999618530273, 25.799999237060547, 24.600000381469727, 38.0, 35.20000076293945, 24.19999885559082, 28.799999237060547, 23.0, 24.799999237060547, 9.0, 7.599999904632568, 20.0, 24.600000381469727] +k eval/a1:test_return:Overcooked-CoordRing6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CoordRing6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [41.20000076293945, 36.0, 48.599998474121094, 53.39999771118164, 48.20000076293945, 49.79999923706055, 38.599998474121094, 37.0, 37.79999923706055, 40.0, 44.0, 43.79999923706055, 35.20000076293945, 27.399999618530273, 24.19999885559082, 22.799999237060547] +k eval/a1:test_return:Overcooked-CoordRing6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [35.79999923706055, 40.39999771118164, 43.20000076293945, 48.0, 50.39999771118164, 54.79999923706055, 34.79999923706055, 38.0, 46.599998474121094, 49.599998474121094, 33.79999923706055, 43.39999771118164, 31.0, 34.599998474121094, 23.799999237060547, 22.600000381469727] +k eval/a1:test_return:Overcooked-CoordRing6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 33.75+/- 1.716 (max: 54.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 39.42+/- 2.333 (max: 54.8) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 22.57+/- 2.18 (max: 38.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 39.25+/- 2.235 (max: 53.4) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 16.77+/- 0.4238 (max: 25.72) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 19.48+/- 0.748 (max: 25.72) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 14.93+/- 0.4341 (max: 17.04) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 15.89+/- 0.4128 (max: 18.76) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.5819+/- 0.03686 (max: 0.95) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.68+/- 0.04593 (max: 0.93) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.3444+/- 0.05052 (max: 0.73) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.7212+/- 0.0491 (max: 0.95) | +| eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 7.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 22.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 7.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 22.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 11.79 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 14.55 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.79 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.11 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.05 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.31 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.05 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.32 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CoordRing6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.19999998807907104, 1.0, 0.19999998807907104, 2.5999999046325684, 0.0, 1.0, 0.0, 3.0, 0.19999998807907104, 2.799999952316284, 0.0, 0.7999999523162842, 0.3999999761581421, 2.3999998569488525, 0.19999998807907104, 3.0, 0.0, 3.0, 0.0, 2.0, 0.3999999761581421, 1.1999999284744263, 0.19999998807907104, 1.0, 0.19999998807907104, 1.1999999284744263, 0.5999999642372131, 0.7999999523162842, 0.7999999523162842, 0.19999998807907104, 0.0, 3.0, 0.0, 2.0, 0.0, 2.5999999046325684, 0.0, 2.0, 0.19999998807907104, 1.0, 0.0, 0.5999999642372131, 0.0, 1.5999999046325684, 0.0, 2.5999999046325684, 0.0, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.19999998807907104, 1.0, 0.0, 3.0, 0.3999999761581421, 2.3999998569488525, 0.0, 2.0, 0.19999998807907104, 1.1999999284744263, 0.0, 3.0, 0.0, 2.0, 0.0, 1.5999999046325684] +k eval/a1:test_return:Overcooked-ForcedCoord6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.19999998807907104, 2.5999999046325684, 0.19999998807907104, 2.799999952316284, 0.19999998807907104, 3.0, 0.3999999761581421, 1.1999999284744263, 0.5999999642372131, 0.7999999523162842, 0.0, 2.0, 0.19999998807907104, 1.0, 0.0, 2.5999999046325684] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 1.0, 0.0, 0.7999999523162842, 0.0, 3.0, 0.19999998807907104, 1.0, 0.7999999523162842, 0.19999998807907104, 0.0, 2.5999999046325684, 0.0, 0.5999999642372131, 0.0, 0.5999999642372131] +k eval/a1:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.95+/- 0.1513 (max: 3.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.675+/- 0.2287 (max: 3.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.062+/- 0.2809 (max: 3.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.112+/- 0.2763 (max: 3.0) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 3.264+/- 0.3728 (max: 7.141) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 2.577+/- 0.6121 (max: 7.141) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 3.421+/- 0.7182 (max: 7.141) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.796+/- 0.6033 (max: 7.141) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0002083+/- 0.0002083 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.000625+/- 0.000625 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating PAIRED_SoftMoE_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [3.5999999046325684, 2.5999999046325684, 8.59999942779541, 3.5999999046325684, 16.19999885559082, 13.399999618530273, 1.7999999523162842, 0.5999999642372131, 6.199999809265137, 2.5999999046325684, 9.59999942779541, 4.599999904632568, 12.59999942779541, 13.399999618530273, 14.399999618530273, 9.59999942779541, 33.79999923706055, 42.79999923706055, 6.599999904632568, 3.799999952316284, 13.0, 8.199999809265137, 15.399999618530273, 8.800000190734863, 5.199999809265137, 3.799999952316284, 18.0, 12.399999618530273, 15.399999618530273, 9.0, 8.399999618530273, 5.599999904632568, 11.800000190734863, 6.0, 13.0, 7.0, 3.0, 0.0, 4.799999713897705, 1.0, 14.799999237060547, 9.800000190734863, 6.0, 3.0, 5.799999713897705, 3.1999998092651367, 11.199999809265137, 8.59999942779541] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [3.5999999046325684, 2.5999999046325684, 1.7999999523162842, 0.5999999642372131, 12.59999942779541, 13.399999618530273, 6.599999904632568, 3.799999952316284, 5.199999809265137, 3.799999952316284, 8.399999618530273, 5.599999904632568, 3.0, 0.0, 6.0, 3.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:low, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [8.59999942779541, 3.5999999046325684, 6.199999809265137, 2.5999999046325684, 14.399999618530273, 9.59999942779541, 13.0, 8.199999809265137, 18.0, 12.399999618530273, 11.800000190734863, 6.0, 4.799999713897705, 1.0, 5.799999713897705, 3.1999998092651367] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:mid, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [16.19999885559082, 13.399999618530273, 9.59999942779541, 4.599999904632568, 33.79999923706055, 42.79999923706055, 15.399999618530273, 8.800000190734863, 15.399999618530273, 9.0, 13.0, 7.0, 14.799999237060547, 9.800000190734863, 11.199999809265137, 8.59999942779541] +k eval/a1:test_return:Overcooked-CounterCircuit6_9:high, v [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 9.221+/- 1.118 (max: 42.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 14.59+/- 2.492 (max: 42.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 5.0+/- 0.951 (max: 13.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 8.075+/- 1.2 (max: 18.0) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 10.17+/- 0.4812 (max: 19.74) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 11.86+/- 0.7161 (max: 19.74) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 8.197+/- 0.8849 (max: 14.71) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.46+/- 0.6527 (max: 14.55) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.06312+/- 0.02012 (max: 0.77) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.12+/- 0.05657 (max: 0.77) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02062+/- 0.01047 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.04875+/- 0.01214 (max: 0.16) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 4.6 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 1.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 8.879 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 4.359 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CounterCircuit6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [14.799999237060547, 23.19999885559082, 16.19999885559082, 32.0, 20.0, 34.39999771118164, 15.199999809265137, 20.0, 19.600000381469727, 58.599998474121094, 20.19999885559082, 31.19999885559082, 19.600000381469727, 51.0, 16.799999237060547, 71.5999984741211, 25.599998474121094, 16.399999618530273, 16.399999618530273, 37.20000076293945, 16.600000381469727, 34.0, 24.399999618530273, 17.0, 11.800000190734863, 18.799999237060547, 17.600000381469727, 35.20000076293945, 16.600000381469727, 23.799999237060547, 13.0, 21.600000381469727, 17.799999237060547, 47.20000076293945, 23.600000381469727, 25.399999618530273, 14.0, 16.399999618530273, 16.799999237060547, 40.79999923706055, 25.799999237060547, 18.0, 12.799999237060547, 23.399999618530273, 19.799999237060547, 42.79999923706055, 23.0, 27.799999237060547] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [14.799999237060547, 23.19999885559082, 15.199999809265137, 20.0, 19.600000381469727, 51.0, 16.399999618530273, 37.20000076293945, 11.800000190734863, 18.799999237060547, 13.0, 21.600000381469727, 14.0, 16.399999618530273, 12.799999237060547, 23.399999618530273] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9, v [0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [16.19999885559082, 32.0, 19.600000381469727, 58.599998474121094, 16.799999237060547, 71.5999984741211, 16.600000381469727, 34.0, 17.600000381469727, 35.20000076293945, 17.799999237060547, 47.20000076293945, 16.799999237060547, 40.79999923706055, 19.799999237060547, 42.79999923706055] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [20.0, 34.39999771118164, 20.19999885559082, 31.19999885559082, 25.599998474121094, 16.399999618530273, 24.399999618530273, 17.0, 16.600000381469727, 23.799999237060547, 23.600000381469727, 25.399999618530273, 25.799999237060547, 18.0, 23.0, 27.799999237060547] +k eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263, 0.0, 1.1999999284744263] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 25.12+/- 1.806 (max: 71.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 23.32+/- 1.298 (max: 34.4) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 20.57+/- 2.548 (max: 51.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 31.46+/- 4.268 (max: 71.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 18.22+/- 0.8 (max: 34.55) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 17.81+/- 0.822 (max: 24.99) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 16.37+/- 1.116 (max: 27.48) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 20.47+/- 1.876 (max: 34.55) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.3327+/- 0.03005 (max: 0.87) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.3319+/- 0.02487 (max: 0.54) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.2475+/- 0.0495 (max: 0.76) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.4187+/- 0.06664 (max: 0.87) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.6+/- 0.08752 (max: 1.2) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.6+/- 0.1549 (max: 1.2) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.6+/- 0.1549 (max: 1.2) | +| eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.6+/- 0.1549 (max: 1.2) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 2.375+/- 0.3464 (max: 4.75) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 2.375+/- 0.6132 (max: 4.75) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 2.375+/- 0.6132 (max: 4.75) | +| eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 2.375+/- 0.6132 (max: 4.75) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 11.8 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 16.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 11.8 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 16.2 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 11.66 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 13.56 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 11.84 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 11.66 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.06 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.21 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.06 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.16 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating PAIRED_SoftMoE_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [45.39999771118164, 47.0, 98.4000015258789, 95.19999694824219, 107.19999694824219, 114.79999542236328, 60.39999771118164, 59.19999694824219, 101.5999984741211, 106.5999984741211, 113.79999542236328, 115.19999694824219, 58.0, 54.0, 116.0, 117.5999984741211, 94.0, 104.79999542236328, 56.19999694824219, 58.39999771118164, 104.5999984741211, 96.5999984741211, 80.5999984741211, 93.0, 45.599998474121094, 48.39999771118164, 101.79999542236328, 100.4000015258789, 95.79999542236328, 117.79999542236328, 53.19999694824219, 50.599998474121094, 114.19999694824219, 107.39999389648438, 104.39999389648438, 109.39999389648438, 62.599998474121094, 61.79999923706055, 103.79999542236328, 104.0, 117.19999694824219, 112.39999389648438, 67.0, 60.0, 118.39999389648438, 109.5999984741211, 81.19999694824219, 96.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [45.39999771118164, 47.0, 60.39999771118164, 59.19999694824219, 58.0, 54.0, 56.19999694824219, 58.39999771118164, 45.599998474121094, 48.39999771118164, 53.19999694824219, 50.599998474121094, 62.599998474121094, 61.79999923706055, 67.0, 60.0] +k eval/a1:test_return:Overcooked-CrampedRoom6_9, v [0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:low, v [0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [98.4000015258789, 95.19999694824219, 101.5999984741211, 106.5999984741211, 116.0, 117.5999984741211, 104.5999984741211, 96.5999984741211, 101.79999542236328, 100.4000015258789, 114.19999694824219, 107.39999389648438, 103.79999542236328, 104.0, 118.39999389648438, 109.5999984741211] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:mid, v [0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [107.19999694824219, 114.79999542236328, 113.79999542236328, 115.19999694824219, 94.0, 104.79999542236328, 80.5999984741211, 93.0, 95.79999542236328, 117.79999542236328, 104.39999389648438, 109.39999389648438, 117.19999694824219, 112.39999389648438, 81.19999694824219, 96.4000015258789] +k eval/a1:test_return:Overcooked-CrampedRoom6_9:high, v [0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863, 0.0, 8.800000190734863] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 88.37+/- 3.628 (max: 118.4) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 103.6+/- 3.045 (max: 117.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 55.49+/- 1.645 (max: 67.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 106.0+/- 1.84 (max: 118.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 30.95+/- 1.134 (max: 51.5) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 40.5+/- 1.371 (max: 51.5) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 24.7+/- 0.3532 (max: 27.13) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 27.65+/- 0.8123 (max: 32.75) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.9142+/- 0.01214 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9225+/- 0.0158 (max: 0.99) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.8287+/- 0.01557 (max: 0.92) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9912+/- 0.002394 (max: 1.0) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9 | 4.4+/- 0.6418 (max: 8.8) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 4.4+/- 1.136 (max: 8.8) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 4.4+/- 1.136 (max: 8.8) | +| eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 4.4+/- 1.136 (max: 8.8) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 5.161+/- 0.7529 (max: 10.32) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 5.161+/- 1.333 (max: 10.32) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 5.161+/- 1.333 (max: 10.32) | +| eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 5.161+/- 1.333 (max: 10.32) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.005+/- 0.0007293 (max: 0.01) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.005+/- 0.001291 (max: 0.01) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.005+/- 0.001291 (max: 0.01) | +| eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.005+/- 0.001291 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 45.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 80.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 45.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 95.2 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 21.19 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 31.24 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 21.24 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 21.19 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.71 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.78 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.71 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.97 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_return_std:Overcooked-CrampedRoom6_9:mid | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.0 | +| min:eval/a1:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.0 | +-------------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED1 against population in Overcooked-CoordRing6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [4.799999713897705, 2.3999998569488525, 19.0, 18.600000381469727, 24.600000381469727, 25.799999237060547, 15.0, 14.0, 29.399999618530273, 33.0, 26.399999618530273, 29.0, 12.799999237060547, 12.199999809265137, 24.799999237060547, 28.0, 28.19999885559082, 30.0, 14.799999237060547, 11.199999809265137, 18.799999237060547, 14.59999942779541, 23.0, 22.399999618530273, 16.19999885559082, 15.59999942779541, 20.399999618530273, 22.799999237060547, 28.0, 31.0, 14.59999942779541, 12.0, 25.0, 20.0, 38.0, 39.39999771118164, 3.1999998092651367, 4.400000095367432, 19.0, 19.19999885559082, 36.0, 34.39999771118164, 15.399999618530273, 13.799999237060547, 16.399999618530273, 13.799999237060547, 25.599998474121094, 25.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [4.799999713897705, 2.3999998569488525, 15.0, 14.0, 12.799999237060547, 12.199999809265137, 14.799999237060547, 11.199999809265137, 16.19999885559082, 15.59999942779541, 14.59999942779541, 12.0, 3.1999998092651367, 4.400000095367432, 15.399999618530273, 13.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [19.0, 18.600000381469727, 29.399999618530273, 33.0, 24.799999237060547, 28.0, 18.799999237060547, 14.59999942779541, 20.399999618530273, 22.799999237060547, 25.0, 20.0, 19.0, 19.19999885559082, 16.399999618530273, 13.799999237060547] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [24.600000381469727, 25.799999237060547, 26.399999618530273, 29.0, 28.19999885559082, 30.0, 23.0, 22.399999618530273, 28.0, 31.0, 38.0, 39.39999771118164, 36.0, 34.39999771118164, 25.599998474121094, 25.0] +----------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 20.67+/- 1.287 (max: 39.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 29.17+/- 1.318 (max: 39.4) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.4+/- 1.204 (max: 16.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 21.42+/- 1.341 (max: 33.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 13.38+/- 0.3369 (max: 16.97) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 13.57+/- 0.5546 (max: 16.97) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.36+/- 0.7336 (max: 15.15) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 14.21+/- 0.3214 (max: 16.84) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.2894+/- 0.02954 (max: 0.83) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.4837+/- 0.04042 (max: 0.83) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.1019+/- 0.01708 (max: 0.19) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.2825+/- 0.03753 (max: 0.65) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 22.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 13.8 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.499 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.499 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.79 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.31 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.08 | +----------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED1 against population in Overcooked-ForcedCoord6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [0.7999999523162842, 0.19999998807907104, 0.5999999642372131, 0.19999998807907104, 0.0, 0.19999998807907104, 1.5999999046325684, 0.3999999761581421, 1.399999976158142, 1.0, 0.3999999761581421, 0.3999999761581421, 0.5999999642372131, 0.0, 1.0, 0.0, 0.3999999761581421, 0.0, 1.5999999046325684, 0.0, 1.399999976158142, 0.5999999642372131, 0.0, 0.3999999761581421, 0.3999999761581421, 0.0, 1.399999976158142, 0.19999998807907104, 1.0, 0.19999998807907104, 1.0, 0.0, 1.0, 0.3999999761581421, 0.3999999761581421, 0.0, 0.3999999761581421, 0.0, 0.19999998807907104, 0.19999998807907104, 0.7999999523162842, 0.19999998807907104, 1.1999999284744263, 0.19999998807907104, 0.7999999523162842, 0.5999999642372131, 1.5999999046325684, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [0.7999999523162842, 0.19999998807907104, 1.5999999046325684, 0.3999999761581421, 0.5999999642372131, 0.0, 1.5999999046325684, 0.0, 0.3999999761581421, 0.0, 1.0, 0.0, 0.3999999761581421, 0.0, 1.1999999284744263, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [0.5999999642372131, 0.19999998807907104, 1.399999976158142, 1.0, 1.0, 0.0, 1.399999976158142, 0.5999999642372131, 1.399999976158142, 0.19999998807907104, 1.0, 0.3999999761581421, 0.19999998807907104, 0.19999998807907104, 0.7999999523162842, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [0.0, 0.19999998807907104, 0.3999999761581421, 0.3999999761581421, 0.3999999761581421, 0.0, 0.0, 0.3999999761581421, 1.0, 0.19999998807907104, 0.3999999761581421, 0.0, 0.7999999523162842, 0.19999998807907104, 1.5999999046325684, 0.5999999642372131] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.5417+/- 0.0712 (max: 1.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.4125+/- 0.1072 (max: 1.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.525+/- 0.1401 (max: 1.6) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.6875+/- 0.1183 (max: 1.4) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 2.71+/- 0.251 (max: 5.426) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 2.318+/- 0.4111 (max: 5.426) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 2.479+/- 0.5021 (max: 5.426) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 3.331+/- 0.3624 (max: 5.103) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0+/- 0.0 (max: 0.0) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------ +Evaluating ACCEL_SoftMoE_SEED1 against population in Overcooked-CounterCircuit6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [9.800000190734863, 7.599999904632568, 13.199999809265137, 12.399999618530273, 12.59999942779541, 12.799999237060547, 6.399999618530273, 8.0, 11.199999809265137, 9.59999942779541, 3.0, 1.399999976158142, 5.599999904632568, 3.3999998569488525, 14.0, 12.399999618530273, 16.600000381469727, 9.0, 11.399999618530273, 8.199999809265137, 20.0, 17.0, 11.399999618530273, 6.0, 10.399999618530273, 8.800000190734863, 21.19999885559082, 18.0, 11.59999942779541, 6.199999809265137, 11.800000190734863, 12.199999809265137, 18.799999237060547, 12.399999618530273, 11.199999809265137, 5.599999904632568, 9.199999809265137, 5.199999809265137, 13.799999237060547, 11.199999809265137, 6.799999713897705, 6.399999618530273, 11.800000190734863, 8.59999942779541, 14.799999237060547, 15.199999809265137, 9.0, 4.199999809265137] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [9.800000190734863, 7.599999904632568, 6.399999618530273, 8.0, 5.599999904632568, 3.3999998569488525, 11.399999618530273, 8.199999809265137, 10.399999618530273, 8.800000190734863, 11.800000190734863, 12.199999809265137, 9.199999809265137, 5.199999809265137, 11.800000190734863, 8.59999942779541] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [13.199999809265137, 12.399999618530273, 11.199999809265137, 9.59999942779541, 14.0, 12.399999618530273, 20.0, 17.0, 21.19999885559082, 18.0, 18.799999237060547, 12.399999618530273, 13.799999237060547, 11.199999809265137, 14.799999237060547, 15.199999809265137] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [12.59999942779541, 12.799999237060547, 3.0, 1.399999976158142, 16.600000381469727, 9.0, 11.399999618530273, 6.0, 11.59999942779541, 6.199999809265137, 11.199999809265137, 5.599999904632568, 6.799999713897705, 6.399999618530273, 9.0, 4.199999809265137] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 10.57+/- 0.6443 (max: 21.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 8.362+/- 1.026 (max: 16.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 8.65+/- 0.6464 (max: 12.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 14.7+/- 0.8503 (max: 21.2) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.37+/- 0.3286 (max: 15.9) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 9.852+/- 0.4803 (max: 12.54) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.64+/- 0.3431 (max: 12.99) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 13.62+/- 0.3743 (max: 15.9) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.06312+/- 0.009956 (max: 0.25) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.02375+/- 0.00875 (max: 0.11) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02875+/- 0.006575 (max: 0.09) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1369+/- 0.01635 (max: 0.25) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 3.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 9.6 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 7.513 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 10.39 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.03 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_SoftMoE_SEED1 against population in Overcooked-AsymmAdvantages6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [0.5999999642372131, 4.199999809265137, 0.7999999523162842, 15.199999809265137, 0.7999999523162842, 3.5999999046325684, 1.0, 5.799999713897705, 1.399999976158142, 36.39999771118164, 0.3999999761581421, 6.399999618530273, 0.7999999523162842, 37.599998474121094, 0.5999999642372131, 60.0, 1.0, 11.59999942779541, 0.3999999761581421, 17.399999618530273, 1.5999999046325684, 12.799999237060547, 2.0, 8.0, 0.5999999642372131, 4.599999904632568, 1.0, 22.19999885559082, 0.19999998807907104, 8.0, 0.3999999761581421, 4.199999809265137, 0.3999999761581421, 29.19999885559082, 0.7999999523162842, 4.599999904632568, 0.5999999642372131, 5.799999713897705, 1.7999999523162842, 21.19999885559082, 0.5999999642372131, 9.800000190734863, 0.7999999523162842, 6.399999618530273, 0.7999999523162842, 23.0, 0.7999999523162842, 4.400000095367432] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [0.5999999642372131, 4.199999809265137, 1.0, 5.799999713897705, 0.7999999523162842, 37.599998474121094, 0.3999999761581421, 17.399999618530273, 0.5999999642372131, 4.599999904632568, 0.3999999761581421, 4.199999809265137, 0.5999999642372131, 5.799999713897705, 0.7999999523162842, 6.399999618530273] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [0.7999999523162842, 15.199999809265137, 1.399999976158142, 36.39999771118164, 0.5999999642372131, 60.0, 1.5999999046325684, 12.799999237060547, 1.0, 22.19999885559082, 0.3999999761581421, 29.19999885559082, 1.7999999523162842, 21.19999885559082, 0.7999999523162842, 23.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [0.7999999523162842, 3.5999999046325684, 0.3999999761581421, 6.399999618530273, 1.0, 11.59999942779541, 2.0, 8.0, 0.19999998807907104, 8.0, 0.7999999523162842, 4.599999904632568, 0.5999999642372131, 9.800000190734863, 0.7999999523162842, 4.400000095367432] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 7.971+/- 1.755 (max: 60.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.937+/- 0.9409 (max: 11.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 5.7+/- 2.385 (max: 37.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 14.27+/- 4.279 (max: 60.0) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 9.27+/- 0.9981 (max: 29.66) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.492+/- 1.044 (max: 15.79) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 7.73+/- 1.436 (max: 24.05) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 12.59+/- 2.261 (max: 29.66) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.09062+/- 0.02645 (max: 0.85) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.02375+/- 0.00841 (max: 0.12) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.05625+/- 0.03769 (max: 0.58) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1919+/- 0.06343 (max: 0.85) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.4 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 0.4 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 2.8 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 2.8 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_SoftMoE_SEED1 against population in Overcooked-CrampedRoom6_9 for xpid plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [11.0, 9.800000190734863, 68.4000015258789, 68.0, 102.39999389648438, 99.19999694824219, 16.399999618530273, 15.59999942779541, 74.5999984741211, 74.4000015258789, 115.5999984741211, 114.19999694824219, 19.19999885559082, 17.399999618530273, 95.79999542236328, 100.4000015258789, 114.79999542236328, 113.0, 16.799999237060547, 18.600000381469727, 95.4000015258789, 95.5999984741211, 92.0, 93.5999984741211, 9.199999809265137, 9.199999809265137, 75.4000015258789, 75.79999542236328, 108.5999984741211, 110.79999542236328, 18.399999618530273, 19.19999885559082, 95.5999984741211, 96.4000015258789, 117.0, 121.0, 15.59999942779541, 15.59999942779541, 89.0, 85.0, 113.0, 111.39999389648438, 16.600000381469727, 20.799999237060547, 97.4000015258789, 94.79999542236328, 110.19999694824219, 107.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [11.0, 9.800000190734863, 16.399999618530273, 15.59999942779541, 19.19999885559082, 17.399999618530273, 16.799999237060547, 18.600000381469727, 9.199999809265137, 9.199999809265137, 18.399999618530273, 19.19999885559082, 15.59999942779541, 15.59999942779541, 16.600000381469727, 20.799999237060547] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [68.4000015258789, 68.0, 74.5999984741211, 74.4000015258789, 95.79999542236328, 100.4000015258789, 95.4000015258789, 95.5999984741211, 75.4000015258789, 75.79999542236328, 95.5999984741211, 96.4000015258789, 89.0, 85.0, 97.4000015258789, 94.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [102.39999389648438, 99.19999694824219, 115.5999984741211, 114.19999694824219, 114.79999542236328, 113.0, 92.0, 93.5999984741211, 108.5999984741211, 110.79999542236328, 117.0, 121.0, 113.0, 111.39999389648438, 110.19999694824219, 107.0] +------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 70.32+/- 5.926 (max: 121.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 109.0+/- 2.066 (max: 121.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 15.59+/- 0.9417 (max: 20.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 86.37+/- 2.897 (max: 100.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 21.84+/- 0.9275 (max: 33.95) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 25.13+/- 1.029 (max: 31.98) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.99+/- 0.4335 (max: 16.95) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 26.4+/- 0.7863 (max: 33.95) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.7017+/- 0.05652 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9887+/- 0.00427 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1587+/- 0.02091 (max: 0.3) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9575+/- 0.008342 (max: 1.0) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 9.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 92.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 9.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 68.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.74 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 16.09 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.74 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 21.61 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.93 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.89 | +------------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED2 against population in Overcooked-CoordRing6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [3.3999998569488525, 4.400000095367432, 15.59999942779541, 17.399999618530273, 19.399999618530273, 20.399999618530273, 14.199999809265137, 13.799999237060547, 25.599998474121094, 25.599998474121094, 17.799999237060547, 20.0, 12.799999237060547, 13.799999237060547, 22.399999618530273, 20.600000381469727, 18.600000381469727, 20.399999618530273, 11.399999618530273, 12.199999809265137, 12.799999237060547, 16.399999618530273, 15.199999809265137, 17.600000381469727, 13.399999618530273, 13.199999809265137, 18.399999618530273, 19.0, 20.799999237060547, 21.19999885559082, 11.0, 13.0, 19.799999237060547, 21.399999618530273, 32.0, 29.399999618530273, 3.1999998092651367, 3.0, 13.799999237060547, 17.19999885559082, 26.0, 29.799999237060547, 12.799999237060547, 15.59999942779541, 10.199999809265137, 15.199999809265137, 21.0, 24.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [3.3999998569488525, 4.400000095367432, 14.199999809265137, 13.799999237060547, 12.799999237060547, 13.799999237060547, 11.399999618530273, 12.199999809265137, 13.399999618530273, 13.199999809265137, 11.0, 13.0, 3.1999998092651367, 3.0, 12.799999237060547, 15.59999942779541] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [15.59999942779541, 17.399999618530273, 25.599998474121094, 25.599998474121094, 22.399999618530273, 20.600000381469727, 12.799999237060547, 16.399999618530273, 18.399999618530273, 19.0, 19.799999237060547, 21.399999618530273, 13.799999237060547, 17.19999885559082, 10.199999809265137, 15.199999809265137] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [19.399999618530273, 20.399999618530273, 17.799999237060547, 20.0, 18.600000381469727, 20.399999618530273, 15.199999809265137, 17.600000381469727, 20.799999237060547, 21.19999885559082, 32.0, 29.399999618530273, 26.0, 29.799999237060547, 21.0, 24.19999885559082] +---------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CoordRing6_9 | 17.01+/- 0.9417 (max: 32.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 22.11+/- 1.208 (max: 32.0) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 10.7+/- 1.108 (max: 15.6) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 18.21+/- 1.079 (max: 25.6) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 12.48+/- 0.2989 (max: 16.31) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 12.41+/- 0.4374 (max: 15.35) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 11.54+/- 0.6308 (max: 14.67) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 13.5+/- 0.349 (max: 16.31) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.174+/- 0.02075 (max: 0.65) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.2569+/- 0.0447 (max: 0.65) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.07+/- 0.01278 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.195+/- 0.02603 (max: 0.39) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 3.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 15.2 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 3.0 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 10.2 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 7.141 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.16 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 7.141 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 11.49 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.07 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.04 | +---------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED2 against population in Overcooked-ForcedCoord6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [3.1999998092651367, 0.19999998807907104, 4.0, 0.3999999761581421, 3.799999952316284, 0.7999999523162842, 5.0, 0.0, 3.3999998569488525, 0.7999999523162842, 3.1999998092651367, 0.5999999642372131, 2.799999952316284, 0.19999998807907104, 1.399999976158142, 0.7999999523162842, 0.7999999523162842, 0.0, 4.0, 0.19999998807907104, 4.599999904632568, 0.19999998807907104, 2.799999952316284, 1.399999976158142, 2.200000047683716, 0.19999998807907104, 2.799999952316284, 0.7999999523162842, 4.0, 1.0, 3.799999952316284, 0.19999998807907104, 2.5999999046325684, 0.0, 1.7999999523162842, 0.5999999642372131, 3.0, 0.0, 3.3999998569488525, 0.19999998807907104, 2.0, 0.3999999761581421, 5.400000095367432, 0.0, 4.0, 0.3999999761581421, 6.0, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [3.1999998092651367, 0.19999998807907104, 5.0, 0.0, 2.799999952316284, 0.19999998807907104, 4.0, 0.19999998807907104, 2.200000047683716, 0.19999998807907104, 3.799999952316284, 0.19999998807907104, 3.0, 0.0, 5.400000095367432, 0.0] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [4.0, 0.3999999761581421, 3.3999998569488525, 0.7999999523162842, 1.399999976158142, 0.7999999523162842, 4.599999904632568, 0.19999998807907104, 2.799999952316284, 0.7999999523162842, 2.5999999046325684, 0.0, 3.3999998569488525, 0.19999998807907104, 4.0, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [3.799999952316284, 0.7999999523162842, 3.1999998092651367, 0.5999999642372131, 0.7999999523162842, 0.0, 2.799999952316284, 1.399999976158142, 4.0, 1.0, 1.7999999523162842, 0.5999999642372131, 2.0, 0.3999999761581421, 6.0, 0.0] +----------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 1.862+/- 0.2511 (max: 6.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 1.825+/- 0.4262 (max: 6.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 1.9+/- 0.4956 (max: 5.4) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 1.862+/- 0.4065 (max: 4.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 4.794+/- 0.4211 (max: 9.165) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 4.871+/- 0.6694 (max: 9.165) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 4.51+/- 0.8705 (max: 8.879) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 5.002+/- 0.6724 (max: 8.485) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.000625+/- 0.0003531 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.000625+/- 0.000625 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +----------------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED2 against population in Overcooked-CounterCircuit6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [8.0, 6.599999904632568, 14.199999809265137, 10.59999942779541, 15.199999809265137, 13.59999942779541, 8.0, 6.0, 12.399999618530273, 10.59999942779541, 2.3999998569488525, 0.7999999523162842, 2.200000047683716, 2.200000047683716, 9.399999618530273, 6.0, 10.59999942779541, 6.599999904632568, 7.399999618530273, 8.800000190734863, 17.19999885559082, 15.59999942779541, 7.799999713897705, 6.199999809265137, 8.199999809265137, 8.800000190734863, 18.799999237060547, 13.59999942779541, 7.199999809265137, 4.799999713897705, 9.199999809265137, 10.399999618530273, 14.799999237060547, 11.59999942779541, 9.199999809265137, 7.199999809265137, 6.799999713897705, 6.399999618530273, 16.19999885559082, 10.800000190734863, 7.799999713897705, 4.199999809265137, 9.0, 8.399999618530273, 18.600000381469727, 12.59999942779541, 7.0, 4.599999904632568] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [8.0, 6.599999904632568, 8.0, 6.0, 2.200000047683716, 2.200000047683716, 7.399999618530273, 8.800000190734863, 8.199999809265137, 8.800000190734863, 9.199999809265137, 10.399999618530273, 6.799999713897705, 6.399999618530273, 9.0, 8.399999618530273] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [14.199999809265137, 10.59999942779541, 12.399999618530273, 10.59999942779541, 9.399999618530273, 6.0, 17.19999885559082, 15.59999942779541, 18.799999237060547, 13.59999942779541, 14.799999237060547, 11.59999942779541, 16.19999885559082, 10.800000190734863, 18.600000381469727, 12.59999942779541] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [15.199999809265137, 13.59999942779541, 2.3999998569488525, 0.7999999523162842, 10.59999942779541, 6.599999904632568, 7.799999713897705, 6.199999809265137, 7.199999809265137, 4.799999713897705, 9.199999809265137, 7.199999809265137, 7.799999713897705, 4.199999809265137, 7.0, 4.599999904632568] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 9.262+/- 0.6191 (max: 18.8) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 7.2+/- 0.9313 (max: 15.2) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 7.275+/- 0.5733 (max: 10.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 13.31+/- 0.8754 (max: 18.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.28+/- 0.4246 (max: 18.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 9.457+/- 0.53 (max: 13.6) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.17+/- 0.4309 (max: 12.8) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 14.2+/- 0.5951 (max: 18.6) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.06229+/- 0.01171 (max: 0.32) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.01812+/- 0.008814 (max: 0.14) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02312+/- 0.005456 (max: 0.08) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1456+/- 0.02204 (max: 0.32) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 0.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 0.8 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 2.2 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 6.0 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 3.919 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 6.258 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 9.165 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_SoftMoE_SEED2 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [2.5999999046325684, 7.199999809265137, 2.3999998569488525, 19.399999618530273, 3.1999998092651367, 8.199999809265137, 2.799999952316284, 9.199999809265137, 2.799999952316284, 43.20000076293945, 2.799999952316284, 7.399999618530273, 3.1999998092651367, 38.0, 3.1999998092651367, 57.599998474121094, 3.1999998092651367, 11.199999809265137, 3.1999998092651367, 19.799999237060547, 3.3999998569488525, 13.399999618530273, 4.0, 6.199999809265137, 2.5999999046325684, 4.799999713897705, 2.799999952316284, 23.0, 2.200000047683716, 7.0, 2.5999999046325684, 7.799999713897705, 4.0, 29.599998474121094, 3.5999999046325684, 6.0, 2.799999952316284, 8.399999618530273, 2.5999999046325684, 21.600000381469727, 3.799999952316284, 8.0, 2.0, 7.599999904632568, 2.5999999046325684, 26.19999885559082, 4.199999809265137, 4.0] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [2.5999999046325684, 7.199999809265137, 2.799999952316284, 9.199999809265137, 3.1999998092651367, 38.0, 3.1999998092651367, 19.799999237060547, 2.5999999046325684, 4.799999713897705, 2.5999999046325684, 7.799999713897705, 2.799999952316284, 8.399999618530273, 2.0, 7.599999904632568] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [2.3999998569488525, 19.399999618530273, 2.799999952316284, 43.20000076293945, 3.1999998092651367, 57.599998474121094, 3.3999998569488525, 13.399999618530273, 2.799999952316284, 23.0, 4.0, 29.599998474121094, 2.5999999046325684, 21.600000381469727, 2.5999999046325684, 26.19999885559082] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [3.1999998092651367, 8.199999809265137, 2.799999952316284, 7.399999618530273, 3.1999998092651367, 11.199999809265137, 4.0, 6.199999809265137, 2.200000047683716, 7.0, 3.5999999046325684, 6.0, 3.799999952316284, 8.0, 4.199999809265137, 4.0] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 9.737+/- 1.712 (max: 57.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 5.312+/- 0.6242 (max: 11.2) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 7.787+/- 2.305 (max: 38.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 16.11+/- 4.192 (max: 57.6) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 12.02+/- 0.9337 (max: 31.69) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 9.808+/- 0.547 (max: 13.66) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 11.32+/- 1.618 (max: 31.69) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 14.92+/- 2.087 (max: 31.66) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.1087+/- 0.02705 (max: 0.81) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.03312+/- 0.006875 (max: 0.09) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.08+/- 0.03578 (max: 0.55) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.2131+/- 0.06624 (max: 0.81) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 2.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 2.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 2.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 2.4 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 6.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 6.258 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 6.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 6.499 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED2 against population in Overcooked-CrampedRoom6_9 for xpid SEED_2_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [10.59999942779541, 10.0, 66.4000015258789, 65.79999542236328, 86.4000015258789, 90.19999694824219, 13.799999237060547, 15.199999809265137, 71.0, 70.19999694824219, 96.5999984741211, 97.19999694824219, 16.600000381469727, 15.399999618530273, 90.4000015258789, 83.79999542236328, 95.19999694824219, 93.19999694824219, 16.399999618530273, 14.0, 84.0, 79.79999542236328, 64.5999984741211, 72.5999984741211, 7.199999809265137, 7.399999618530273, 70.79999542236328, 67.5999984741211, 91.0, 91.0, 15.0, 16.19999885559082, 80.79999542236328, 84.79999542236328, 104.5999984741211, 100.19999694824219, 15.0, 14.399999618530273, 81.0, 78.79999542236328, 96.79999542236328, 93.0, 18.600000381469727, 16.0, 90.19999694824219, 79.4000015258789, 96.5999984741211, 97.5999984741211] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [10.59999942779541, 10.0, 13.799999237060547, 15.199999809265137, 16.600000381469727, 15.399999618530273, 16.399999618530273, 14.0, 7.199999809265137, 7.399999618530273, 15.0, 16.19999885559082, 15.0, 14.399999618530273, 18.600000381469727, 16.0] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [66.4000015258789, 65.79999542236328, 71.0, 70.19999694824219, 90.4000015258789, 83.79999542236328, 84.0, 79.79999542236328, 70.79999542236328, 67.5999984741211, 80.79999542236328, 84.79999542236328, 81.0, 78.79999542236328, 90.19999694824219, 79.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [86.4000015258789, 90.19999694824219, 96.5999984741211, 97.19999694824219, 95.19999694824219, 93.19999694824219, 64.5999984741211, 72.5999984741211, 91.0, 91.0, 104.5999984741211, 100.19999694824219, 96.79999542236328, 93.0, 96.5999984741211, 97.5999984741211] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 61.11+/- 5.062 (max: 104.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 91.67+/- 2.52 (max: 104.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 13.86+/- 0.8296 (max: 18.6) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 77.8+/- 2.035 (max: 90.4) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 24.45+/- 1.174 (max: 36.21) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 29.64+/- 0.6716 (max: 33.11) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 13.65+/- 0.4103 (max: 16.01) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 30.06+/- 0.8101 (max: 36.21) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.6667+/- 0.05556 (max: 0.99) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.9556+/- 0.008849 (max: 0.99) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.1325+/- 0.0159 (max: 0.23) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9119+/- 0.009274 (max: 0.95) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 7.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 64.6 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 7.2 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 65.8 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 10.45 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 24.41 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 10.45 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 24.8 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.86 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.02 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.83 | +-------------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED3 against population in Overcooked-CoordRing6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CoordRing6_9, v [3.5999999046325684, 4.400000095367432, 20.0, 16.600000381469727, 24.19999885559082, 24.0, 13.799999237060547, 16.399999618530273, 30.19999885559082, 35.0, 25.19999885559082, 25.399999618530273, 14.59999942779541, 11.59999942779541, 23.19999885559082, 25.399999618530273, 26.19999885559082, 28.399999618530273, 18.19999885559082, 12.799999237060547, 19.0, 20.600000381469727, 21.600000381469727, 21.0, 16.399999618530273, 14.799999237060547, 22.399999618530273, 24.600000381469727, 23.19999885559082, 28.799999237060547, 12.799999237060547, 12.399999618530273, 25.0, 23.799999237060547, 30.0, 36.20000076293945, 3.1999998092651367, 2.3999998569488525, 18.399999618530273, 20.19999885559082, 26.599998474121094, 30.599998474121094, 12.799999237060547, 14.0, 14.59999942779541, 17.19999885559082, 18.600000381469727, 25.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:low, v [3.5999999046325684, 4.400000095367432, 13.799999237060547, 16.399999618530273, 14.59999942779541, 11.59999942779541, 18.19999885559082, 12.799999237060547, 16.399999618530273, 14.799999237060547, 12.799999237060547, 12.399999618530273, 3.1999998092651367, 2.3999998569488525, 12.799999237060547, 14.0] +k eval/a0:test_return:Overcooked-CoordRing6_9:mid, v [20.0, 16.600000381469727, 30.19999885559082, 35.0, 23.19999885559082, 25.399999618530273, 19.0, 20.600000381469727, 22.399999618530273, 24.600000381469727, 25.0, 23.799999237060547, 18.399999618530273, 20.19999885559082, 14.59999942779541, 17.19999885559082] +k eval/a0:test_return:Overcooked-CoordRing6_9:high, v [24.19999885559082, 24.0, 25.19999885559082, 25.399999618530273, 26.19999885559082, 28.399999618530273, 21.600000381469727, 21.0, 23.19999885559082, 28.799999237060547, 30.0, 36.20000076293945, 26.599998474121094, 30.599998474121094, 18.600000381469727, 25.19999885559082] +------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CoordRing6_9 | 19.91+/- 1.128 (max: 36.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:high | 25.95+/- 1.063 (max: 36.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:low | 11.51+/- 1.283 (max: 18.2) | +| eval/a0:test_return:Overcooked-CoordRing6_9:mid | 22.26+/- 1.3 (max: 35.0) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9 | 13.51+/- 0.3541 (max: 19.29) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 13.57+/- 0.58 (max: 19.29) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 12.26+/- 0.7193 (max: 15.17) | +| eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 14.71+/- 0.3518 (max: 17.37) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.2606+/- 0.02445 (max: 0.68) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.38+/- 0.03378 (max: 0.67) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.09625+/- 0.01622 (max: 0.23) | +| eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.3056+/- 0.03643 (max: 0.68) | +| min:eval/a0:test_return:Overcooked-CoordRing6_9 | 2.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:high | 18.6 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:low | 2.4 | +| min:eval/a0:test_return:Overcooked-CoordRing6_9:mid | 14.6 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9 | 6.499 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:high | 10.65 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:low | 6.499 | +| min:eval/a0:test_return_std:Overcooked-CoordRing6_9:mid | 12.6 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:high | 0.17 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CoordRing6_9:mid | 0.1 | +------------------------------------------------------------------------------------------------ +Evaluating ACCEL_SoftMoE_SEED3 against population in Overcooked-ForcedCoord6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-ForcedCoord6_9, v [6.399999618530273, 0.3999999761581421, 6.599999904632568, 1.0, 3.1999998092651367, 0.3999999761581421, 7.799999713897705, 0.0, 6.399999618530273, 0.5999999642372131, 4.400000095367432, 0.5999999642372131, 5.199999809265137, 2.200000047683716, 4.599999904632568, 1.399999976158142, 1.5999999046325684, 1.0, 7.199999809265137, 0.3999999761581421, 5.599999904632568, 0.0, 5.799999713897705, 0.3999999761581421, 4.0, 0.3999999761581421, 5.599999904632568, 0.19999998807907104, 8.0, 0.3999999761581421, 5.799999713897705, 0.3999999761581421, 5.0, 0.0, 3.5999999046325684, 0.3999999761581421, 2.799999952316284, 0.19999998807907104, 4.0, 1.5999999046325684, 3.5999999046325684, 0.7999999523162842, 8.0, 0.19999998807907104, 7.599999904632568, 0.5999999642372131, 9.800000190734863, 0.3999999761581421] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:low, v [6.399999618530273, 0.3999999761581421, 7.799999713897705, 0.0, 5.199999809265137, 2.200000047683716, 7.199999809265137, 0.3999999761581421, 4.0, 0.3999999761581421, 5.799999713897705, 0.3999999761581421, 2.799999952316284, 0.19999998807907104, 8.0, 0.19999998807907104] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:mid, v [6.599999904632568, 1.0, 6.399999618530273, 0.5999999642372131, 4.599999904632568, 1.399999976158142, 5.599999904632568, 0.0, 5.599999904632568, 0.19999998807907104, 5.0, 0.0, 4.0, 1.5999999046325684, 7.599999904632568, 0.5999999642372131] +k eval/a0:test_return:Overcooked-ForcedCoord6_9:high, v [3.1999998092651367, 0.3999999761581421, 4.400000095367432, 0.5999999642372131, 1.5999999046325684, 1.0, 5.799999713897705, 0.3999999761581421, 8.0, 0.3999999761581421, 3.5999999046325684, 0.3999999761581421, 3.5999999046325684, 0.7999999523162842, 9.800000190734863, 0.3999999761581421] +---------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-ForcedCoord6_9 | 3.054+/- 0.4147 (max: 9.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 2.775+/- 0.7398 (max: 9.8) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 3.212+/- 0.7708 (max: 8.0) | +| eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 3.175+/- 0.6836 (max: 7.6) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 5.863+/- 0.4681 (max: 10.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 5.686+/- 0.6958 (max: 9.998) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 5.847+/- 0.8881 (max: 10.2) | +| eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 6.056+/- 0.8845 (max: 10.11) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.00125+/- 0.0004824 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0+/- 0.0 (max: 0.0) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.00125+/- 0.0008539 (max: 0.01) | +| eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0025+/- 0.001118 (max: 0.01) | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:high | 0.4 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:high | 2.8 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_return_std:Overcooked-ForcedCoord6_9:mid | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-ForcedCoord6_9:mid | 0.0 | +---------------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED3 against population in Overcooked-CounterCircuit6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CounterCircuit6_9, v [10.59999942779541, 10.800000190734863, 21.799999237060547, 13.59999942779541, 15.59999942779541, 13.199999809265137, 6.599999904632568, 6.599999904632568, 11.0, 11.199999809265137, 5.799999713897705, 1.399999976158142, 1.5999999046325684, 1.5999999046325684, 9.199999809265137, 9.199999809265137, 10.399999618530273, 9.0, 8.199999809265137, 7.399999618530273, 22.399999618530273, 17.799999237060547, 13.199999809265137, 7.399999618530273, 10.800000190734863, 8.399999618530273, 19.19999885559082, 15.799999237060547, 8.800000190734863, 4.199999809265137, 12.0, 9.800000190734863, 19.0, 13.59999942779541, 10.399999618530273, 8.0, 7.599999904632568, 7.199999809265137, 15.0, 13.799999237060547, 12.799999237060547, 5.0, 10.199999809265137, 9.0, 20.399999618530273, 17.600000381469727, 12.799999237060547, 4.199999809265137] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:low, v [10.59999942779541, 10.800000190734863, 6.599999904632568, 6.599999904632568, 1.5999999046325684, 1.5999999046325684, 8.199999809265137, 7.399999618530273, 10.800000190734863, 8.399999618530273, 12.0, 9.800000190734863, 7.599999904632568, 7.199999809265137, 10.199999809265137, 9.0] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:mid, v [21.799999237060547, 13.59999942779541, 11.0, 11.199999809265137, 9.199999809265137, 9.199999809265137, 22.399999618530273, 17.799999237060547, 19.19999885559082, 15.799999237060547, 19.0, 13.59999942779541, 15.0, 13.799999237060547, 20.399999618530273, 17.600000381469727] +k eval/a0:test_return:Overcooked-CounterCircuit6_9:high, v [15.59999942779541, 13.199999809265137, 5.799999713897705, 1.399999976158142, 10.399999618530273, 9.0, 13.199999809265137, 7.399999618530273, 8.800000190734863, 4.199999809265137, 10.399999618530273, 8.0, 12.799999237060547, 5.0, 12.799999237060547, 4.199999809265137] +------------------------------------------------------------------------------------------------------ +| eval/a0:test_return:Overcooked-CounterCircuit6_9 | 10.86+/- 0.7338 (max: 22.4) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 8.887+/- 1.011 (max: 15.6) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 8.025+/- 0.7492 (max: 12.0) | +| eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 15.66+/- 1.071 (max: 22.4) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 11.78+/- 0.4498 (max: 19.79) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 10.08+/- 0.437 (max: 12.07) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 10.35+/- 0.5525 (max: 12.65) | +| eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 14.91+/- 0.6494 (max: 19.79) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.07833+/- 0.01431 (max: 0.38) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.02625+/- 0.006575 (max: 0.08) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.02625+/- 0.006382 (max: 0.07) | +| eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.1825+/- 0.02747 (max: 0.38) | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9 | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:high | 1.4 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:low | 1.6 | +| min:eval/a0:test_return:Overcooked-CounterCircuit6_9:mid | 9.2 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9 | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:high | 5.103 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:low | 5.426 | +| min:eval/a0:test_return_std:Overcooked-CounterCircuit6_9:mid | 9.968 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-CounterCircuit6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------ +Evaluating ACCEL_SoftMoE_SEED3 against population in Overcooked-AsymmAdvantages6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9, v [1.0, 6.0, 1.7999999523162842, 14.399999618530273, 2.200000047683716, 3.0, 1.5999999046325684, 9.199999809265137, 1.399999976158142, 42.39999771118164, 2.3999998569488525, 3.1999998092651367, 0.7999999523162842, 37.0, 1.0, 56.0, 1.0, 11.59999942779541, 1.5999999046325684, 19.799999237060547, 3.0, 9.59999942779541, 2.200000047683716, 4.799999713897705, 0.5999999642372131, 3.3999998569488525, 1.0, 18.19999885559082, 1.1999999284744263, 5.799999713897705, 0.7999999523162842, 6.599999904632568, 1.1999999284744263, 25.0, 2.0, 5.400000095367432, 0.7999999523162842, 5.400000095367432, 1.7999999523162842, 24.399999618530273, 1.7999999523162842, 8.800000190734863, 0.19999998807907104, 9.199999809265137, 1.1999999284744263, 22.600000381469727, 2.3999998569488525, 2.200000047683716] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low, v [1.0, 6.0, 1.5999999046325684, 9.199999809265137, 0.7999999523162842, 37.0, 1.5999999046325684, 19.799999237060547, 0.5999999642372131, 3.3999998569488525, 0.7999999523162842, 6.599999904632568, 0.7999999523162842, 5.400000095367432, 0.19999998807907104, 9.199999809265137] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid, v [1.7999999523162842, 14.399999618530273, 1.399999976158142, 42.39999771118164, 1.0, 56.0, 3.0, 9.59999942779541, 1.0, 18.19999885559082, 1.1999999284744263, 25.0, 1.7999999523162842, 24.399999618530273, 1.1999999284744263, 22.600000381469727] +k eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high, v [2.200000047683716, 3.0, 2.3999998569488525, 3.1999998092651367, 1.0, 11.59999942779541, 2.200000047683716, 4.799999713897705, 1.1999999284744263, 5.799999713897705, 2.0, 5.400000095367432, 1.7999999523162842, 8.800000190734863, 2.3999998569488525, 2.200000047683716] +------------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 8.104+/- 1.707 (max: 56.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 3.75+/- 0.7297 (max: 11.6) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 6.5+/- 2.401 (max: 37.0) | +| eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 14.06+/- 4.16 (max: 56.0) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 9.987+/- 0.9854 (max: 30.59) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 7.933+/- 0.6595 (max: 12.43) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 8.856+/- 1.681 (max: 26.89) | +| eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 13.17+/- 2.196 (max: 30.59) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.08833+/- 0.0261 (max: 0.75) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.01812+/- 0.005643 (max: 0.07) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.06938+/- 0.04139 (max: 0.62) | +| eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.1775+/- 0.06143 (max: 0.75) | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9 | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:high | 1.0 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:low | 0.2 | +| min:eval/a0:test_return:Overcooked-AsymmAdvantages6_9:mid | 1.0 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9 | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:high | 4.359 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:low | 1.99 | +| min:eval/a0:test_return_std:Overcooked-AsymmAdvantages6_9:mid | 4.359 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9 | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:high | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:low | 0.0 | +| min:eval/a0:test_solved_rate:Overcooked-AsymmAdvantages6_9:mid | 0.0 | +------------------------------------------------------------------------------------------------------- +Evaluating ACCEL_SoftMoE_SEED3 against population in Overcooked-CrampedRoom6_9 for xpid SEED_3_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 +k eval/a0:test_return:Overcooked-CrampedRoom6_9, v [12.799999237060547, 12.0, 70.4000015258789, 67.0, 97.0, 94.5999984741211, 17.799999237060547, 14.799999237060547, 75.19999694824219, 69.0, 104.0, 102.79999542236328, 20.0, 21.399999618530273, 89.79999542236328, 93.0, 104.0, 110.0, 21.0, 23.799999237060547, 84.0, 81.5999984741211, 76.4000015258789, 78.5999984741211, 9.0, 12.399999618530273, 77.19999694824219, 67.19999694824219, 98.4000015258789, 96.5999984741211, 19.0, 20.799999237060547, 84.5999984741211, 86.0, 105.5999984741211, 97.79999542236328, 19.799999237060547, 18.399999618530273, 80.4000015258789, 75.79999542236328, 95.4000015258789, 104.19999694824219, 22.799999237060547, 21.19999885559082, 86.79999542236328, 80.79999542236328, 99.0, 98.4000015258789] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:low, v [12.799999237060547, 12.0, 17.799999237060547, 14.799999237060547, 20.0, 21.399999618530273, 21.0, 23.799999237060547, 9.0, 12.399999618530273, 19.0, 20.799999237060547, 19.799999237060547, 18.399999618530273, 22.799999237060547, 21.19999885559082] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:mid, v [70.4000015258789, 67.0, 75.19999694824219, 69.0, 89.79999542236328, 93.0, 84.0, 81.5999984741211, 77.19999694824219, 67.19999694824219, 84.5999984741211, 86.0, 80.4000015258789, 75.79999542236328, 86.79999542236328, 80.79999542236328] +k eval/a0:test_return:Overcooked-CrampedRoom6_9:high, v [97.0, 94.5999984741211, 104.0, 102.79999542236328, 104.0, 110.0, 76.4000015258789, 78.5999984741211, 98.4000015258789, 96.5999984741211, 105.5999984741211, 97.79999542236328, 95.4000015258789, 104.19999694824219, 99.0, 98.4000015258789] +-------------------------------------------------------------------------------------------------- +| eval/a0:test_return:Overcooked-CrampedRoom6_9 | 64.97+/- 5.082 (max: 110.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 97.67+/- 2.235 (max: 110.0) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 17.94+/- 1.098 (max: 23.8) | +| eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 79.3+/- 2.015 (max: 93.0) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 25.19+/- 1.1 (max: 35.14) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 30.55+/- 0.7829 (max: 35.14) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 15.09+/- 0.382 (max: 17.16) | +| eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 29.93+/- 0.6353 (max: 33.49) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.7021+/- 0.05005 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.96+/- 0.007692 (max: 1.0) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.225+/- 0.02422 (max: 0.37) | +| eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.9212+/- 0.009259 (max: 0.99) | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9 | 9.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:high | 76.4 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:low | 9.0 | +| min:eval/a0:test_return:Overcooked-CrampedRoom6_9:mid | 67.0 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9 | 11.45 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:high | 23.75 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:low | 11.45 | +| min:eval/a0:test_return_std:Overcooked-CrampedRoom6_9:mid | 25.51 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9 | 0.04 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:high | 0.88 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:low | 0.04 | +| min:eval/a0:test_solved_rate:Overcooked-CrampedRoom6_9:mid | 0.85 | +-------------------------------------------------------------------------------------------------- diff --git a/src/train_baseline_dr_lstm.sh b/src/train_baseline_dr_lstm.sh new file mode 100755 index 0000000..1ea62df --- /dev/null +++ b/src/train_baseline_dr_lstm.sh @@ -0,0 +1,67 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=dr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 \ No newline at end of file diff --git a/src/train_baseline_dr_s5.sh b/src/train_baseline_dr_s5.sh new file mode 100755 index 0000000..db78691 --- /dev/null +++ b/src/train_baseline_dr_s5.sh @@ -0,0 +1,71 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=dr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=s5 \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_s5_n_blocks=2 \ +--student_s5_n_layers=2 \ +--student_s5_layernorm_pos=pre \ +--student_s5_activation=half_glu1 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 \ No newline at end of file diff --git a/src/train_baseline_dr_softmoe_lstm.sh b/src/train_baseline_dr_softmoe_lstm.sh new file mode 100755 index 0000000..21ca311 --- /dev/null +++ b/src/train_baseline_dr_softmoe_lstm.sh @@ -0,0 +1,70 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=dr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=2 \ +--student_is_soft_moe=True \ +--student_soft_moe_num_experts=4 \ +--student_soft_moe_num_slots=32 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_dr-overcooked6x9w15_fs_IMAGE-r1s_32p_1e_400t_ae1e-05-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 \ No newline at end of file diff --git a/src/train_baseline_p_accel_lstm.sh b/src/train_baseline_p_accel_lstm.sh new file mode 100755 index 0000000..1e1d893 --- /dev/null +++ b/src/train_baseline_p_accel_lstm.sh @@ -0,0 +1,81 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.8 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.1 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=True \ +--plr_force_unique=True \ +--plr_mutation_fn=default \ +--plr_n_mutations=20 \ +--plr_mutation_criterion=batch \ +--plr_mutation_subsample_size=4 \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 \ No newline at end of file diff --git a/src/train_baseline_p_accel_s5.sh b/src/train_baseline_p_accel_s5.sh new file mode 100755 index 0000000..60416e5 --- /dev/null +++ b/src/train_baseline_p_accel_s5.sh @@ -0,0 +1,85 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.8 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.1 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=True \ +--plr_force_unique=True \ +--plr_mutation_fn=default \ +--plr_n_mutations=20 \ +--plr_mutation_criterion=batch \ +--plr_mutation_subsample_size=4 \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=s5 \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_s5_n_blocks=2 \ +--student_s5_n_layers=2 \ +--student_s5_layernorm_pos=pre \ +--student_s5_activation=half_glu1 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 \ No newline at end of file diff --git a/src/train_baseline_p_accel_softmoe_lstm.sh b/src/train_baseline_p_accel_softmoe_lstm.sh new file mode 100755 index 0000000..b085621 --- /dev/null +++ b/src/train_baseline_p_accel_softmoe_lstm.sh @@ -0,0 +1,84 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.8 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.1 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=True \ +--plr_force_unique=True \ +--plr_mutation_fn=default \ +--plr_n_mutations=20 \ +--plr_mutation_criterion=batch \ +--plr_mutation_subsample_size=4 \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=2 \ +--student_is_soft_moe=True \ +--student_soft_moe_num_experts=4 \ +--student_soft_moe_num_slots=32 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.8b4000t0.1s0.3m0.5r_mdef20bat_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 \ No newline at end of file diff --git a/src/train_baseline_p_plr_lstm.sh b/src/train_baseline_p_plr_lstm.sh new file mode 100755 index 0000000..bb0ccd5 --- /dev/null +++ b/src/train_baseline_p_plr_lstm.sh @@ -0,0 +1,77 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.5 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.1 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=True \ +--plr_force_unique=True \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_0 \ No newline at end of file diff --git a/src/train_baseline_p_plr_s5.sh b/src/train_baseline_p_plr_s5.sh new file mode 100755 index 0000000..c8e31ec --- /dev/null +++ b/src/train_baseline_p_plr_s5.sh @@ -0,0 +1,81 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.5 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.1 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=True \ +--plr_force_unique=True \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=s5 \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_s5_n_blocks=2 \ +--student_s5_n_layers=2 \ +--student_s5_layernorm_pos=pre \ +--student_s5_activation=half_glu1 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_0 \ No newline at end of file diff --git a/src/train_baseline_p_plr_softmoe_lstm.sh b/src/train_baseline_p_plr_softmoe_lstm.sh new file mode 100755 index 0000000..03f17d4 --- /dev/null +++ b/src/train_baseline_p_plr_softmoe_lstm.sh @@ -0,0 +1,80 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--is_multi_agent=True \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.5 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.1 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=True \ +--plr_force_unique=True \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=2 \ +--student_is_soft_moe=True \ +--student_soft_moe_num_experts=4 \ +--student_soft_moe_num_slots=32 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=15 \ +--overcooked_replace_wall_pos=True \ +--overcooked_sample_n_walls=True \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_plr-overcooked6x9w15_fs_IMAGE-rpf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_400t_ae1e-05_smm-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___0 \ No newline at end of file diff --git a/src/train_baseline_pop_paired_lstm.sh b/src/train_baseline_pop_paired_lstm.sh new file mode 100755 index 0000000..e0b1953 --- /dev/null +++ b/src/train_baseline_pop_paired_lstm.sh @@ -0,0 +1,89 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=paired \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--verbose=False \ +--is_multi_agent=True \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=2 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=relative_regret \ +--student_gae_lambda=0.98 \ +--teacher_discount=0.999 \ +--teacher_lr_anneal_steps=0 \ +--teacher_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--teacher_entropy_coef=0.01 \ +--teacher_value_loss_coef=0.5 \ +--teacher_n_unroll_update=5 \ +--teacher_ppo_n_epochs=8 \ +--teacher_ppo_n_minibatches=4 \ +--teacher_ppo_clip_eps=0.2 \ +--teacher_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--teacher_model_name=default_teacher_cnn \ +--teacher_recurrent_arch=lstm \ +--teacher_recurrent_hidden_dim=64 \ +--teacher_hidden_dim=64 \ +--teacher_n_hidden_layers=1 \ +--teacher_n_conv_filters=128 \ +--teacher_scalar_embed_dim=10 \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=5 \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--overcooked_ued_replace_wall_pos=True \ +--overcooked_ued_fixed_n_wall_steps=False \ +--overcooked_ued_first_wall_pos_sets_budget=True \ +--overcooked_ued_noise_dim=50 \ +--overcooked_ued_n_walls=15 \ +--overcooked_ued_normalize_obs=True \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lstm_h64_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 \ No newline at end of file diff --git a/src/train_baseline_pop_paired_s5.sh b/src/train_baseline_pop_paired_s5.sh new file mode 100755 index 0000000..06fdd82 --- /dev/null +++ b/src/train_baseline_pop_paired_s5.sh @@ -0,0 +1,93 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=paired \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--verbose=False \ +--is_multi_agent=True \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=2 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=relative_regret \ +--student_gae_lambda=0.98 \ +--teacher_discount=0.999 \ +--teacher_lr_anneal_steps=0 \ +--teacher_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--teacher_entropy_coef=0.01 \ +--teacher_value_loss_coef=0.5 \ +--teacher_n_unroll_update=5 \ +--teacher_ppo_n_epochs=8 \ +--teacher_ppo_n_minibatches=4 \ +--teacher_ppo_clip_eps=0.2 \ +--teacher_ppo_clip_value_loss=True \ +--student_recurrent_arch=s5 \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=3 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_s5_n_blocks=2 \ +--student_s5_n_layers=2 \ +--student_s5_layernorm_pos=pre \ +--student_s5_activation=half_glu1 \ +--student_agent_kind=mappo \ +--teacher_model_name=default_teacher_cnn \ +--teacher_recurrent_arch=lstm \ +--teacher_recurrent_hidden_dim=64 \ +--teacher_hidden_dim=64 \ +--teacher_n_hidden_layers=1 \ +--teacher_n_conv_filters=128 \ +--teacher_scalar_embed_dim=10 \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=5 \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--overcooked_ued_replace_wall_pos=True \ +--overcooked_ued_fixed_n_wall_steps=False \ +--overcooked_ued_first_wall_pos_sets_budget=True \ +--overcooked_ued_noise_dim=50 \ +--overcooked_ued_n_walls=15 \ +--overcooked_ued_normalize_obs=True \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc3se5ba_re_lpr_ahg1_s5_h64nb2nl2_tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 \ No newline at end of file diff --git a/src/train_baseline_pop_paired_softmoe_lstm.sh b/src/train_baseline_pop_paired_softmoe_lstm.sh new file mode 100755 index 0000000..ffdfe1a --- /dev/null +++ b/src/train_baseline_pop_paired_softmoe_lstm.sh @@ -0,0 +1,92 @@ +DEFAULTVALUE=4 +DEFAULTSEED=2 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=${seed} \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=paired \ +--n_devices=1 \ +--student_model_name=default_student_actor_cnn \ +--student_critic_model_name=default_student_critic_cnn \ +--env_name=Overcooked \ +--verbose=False \ +--is_multi_agent=True \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=2 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=400 \ +--lr=0.0003 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=relative_regret \ +--student_gae_lambda=0.98 \ +--teacher_discount=0.999 \ +--teacher_lr_anneal_steps=0 \ +--teacher_gae_lambda=0.98 \ +--student_entropy_coef=0.01 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=8 \ +--student_ppo_n_minibatches=4 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--teacher_entropy_coef=0.01 \ +--teacher_value_loss_coef=0.5 \ +--teacher_n_unroll_update=5 \ +--teacher_ppo_n_epochs=8 \ +--teacher_ppo_n_minibatches=4 \ +--teacher_ppo_clip_eps=0.2 \ +--teacher_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=64 \ +--student_hidden_dim=64 \ +--student_n_hidden_layers=2 \ +--student_is_soft_moe=True \ +--student_soft_moe_num_experts=4 \ +--student_soft_moe_num_slots=32 \ +--student_n_conv_layers=3 \ +--student_n_conv_filters=32 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_agent_kind=mappo \ +--teacher_model_name=default_teacher_cnn \ +--teacher_recurrent_arch=lstm \ +--teacher_recurrent_hidden_dim=64 \ +--teacher_hidden_dim=64 \ +--teacher_n_hidden_layers=1 \ +--teacher_n_conv_filters=128 \ +--teacher_scalar_embed_dim=10 \ +--overcooked_height=6 \ +--overcooked_width=9 \ +--overcooked_n_walls=5 \ +--overcooked_normalize_obs=True \ +--overcooked_max_steps=400 \ +--overcooked_random_reset=False \ +--overcooked_ued_replace_wall_pos=True \ +--overcooked_ued_fixed_n_wall_steps=False \ +--overcooked_ued_first_wall_pos_sets_budget=True \ +--overcooked_ued_noise_dim=50 \ +--overcooked_ued_n_walls=15 \ +--overcooked_ued_normalize_obs=True \ +--n_shaped_reward_updates=30000 \ +--test_n_episodes=10 \ +--test_env_names=Overcooked-CoordRing6_9,Overcooked-ForcedCoord6_9,Overcooked-CounterCircuit6_9,Overcooked-AsymmAdvantages6_9,Overcooked-CrampedRoom6_9 \ +--overcooked_test_normalize_obs=True \ +--xpid=SEED_${seed}_paired-overcooked6x9w5_ld50_rb-r2s_32p_1e_400t_ae1e-05_sr-ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98_pc0.2_h64cf32fc2se5ba_re_lstm_h64__SoftMoE_4E_32S___tch_ppo_lr0.0003g0.999cv0.5ce0.01e8mb4l0.98pc0.2_h64cf128fc1se10ba_re_lstm_h64_0 \ No newline at end of file diff --git a/src/train_baselines_lstm6x9.sh b/src/train_baselines_lstm6x9.sh new file mode 100755 index 0000000..ee83274 --- /dev/null +++ b/src/train_baselines_lstm6x9.sh @@ -0,0 +1,10 @@ +DEFAULTVALUE=4 +DEFAULTSEED=1 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +echo "Using device ${device} and seed ${seed}" + +./train_baseline_p_plr_lstm.sh $device $seed +./train_baseline_p_accel_lstm.sh $device $seed +./train_baseline_pop_paired_lstm.sh $device $seed +./train_baseline_dr_lstm.sh $device $seed diff --git a/src/train_baselines_s56x9.sh b/src/train_baselines_s56x9.sh new file mode 100755 index 0000000..67e746c --- /dev/null +++ b/src/train_baselines_s56x9.sh @@ -0,0 +1,10 @@ +DEFAULTVALUE=4 +DEFAULTSEED=1 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +echo "Using device ${device} and seed ${seed}" + +./train_baseline_p_plr_s5.sh $device $seed +./train_baseline_p_accel_s5.sh $device $seed +./train_baseline_pop_paired_s5.sh $device $seed +./train_baseline_dr_s5.sh $device $seed diff --git a/src/train_baselines_softmoe_lstm6x9.sh b/src/train_baselines_softmoe_lstm6x9.sh new file mode 100755 index 0000000..6c04370 --- /dev/null +++ b/src/train_baselines_softmoe_lstm6x9.sh @@ -0,0 +1,10 @@ +DEFAULTVALUE=4 +DEFAULTSEED=1 +device="${1:-$DEFAULTVALUE}" +seed="${2:-$DEFAULTSEED}" +echo "Using device ${device} and seed ${seed}" + +./train_baseline_p_plr_softmoe_lstm.sh $device $seed +./train_baseline_p_accel_softmoe_lstm.sh $device $seed +./train_baseline_pop_paired_softmoe_lstm.sh $device $seed +./train_baseline_dr_softmoe_lstm.sh $device $seed diff --git a/src/train_maze.sh b/src/train_maze.sh new file mode 100755 index 0000000..7dcddfc --- /dev/null +++ b/src/train_maze.sh @@ -0,0 +1,67 @@ +CUDA_VISIBLE_DEVICES=$1 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--seed=1 \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_cnn \ +--env_name=Maze \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=False \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=256 \ +--lr=5e-05 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.995 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.5 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.1 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=False \ +--plr_force_unique=True \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.0 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=5 \ +--student_ppo_n_minibatches=1 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=lstm \ +--student_recurrent_hidden_dim=256 \ +--student_hidden_dim=32 \ +--student_n_hidden_layers=1 \ +--student_n_conv_filters=16 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--maze_height=13 \ +--maze_width=13 \ +--maze_n_walls=60 \ +--maze_replace_wall_pos=True \ +--maze_sample_n_walls=False \ +--maze_see_agent=False \ +--maze_normalize_obs=True \ +--maze_obs_agent_pos=False \ +--maze_max_episode_steps=250 \ +--test_n_episodes=10 \ +--test_env_names=Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze \ +--maze_test_see_agent=False \ +--maze_test_normalize_obs=True \ +--xpid=plr-maze13x13w60na_f-rf_p0.5b4000t0.1s0.3m0.5r_r1s_32p_1e_256t_ae1e-05_smm-ppo_lr5e-05g0.995cv0.5ce0.0e5mb1l0.98_pc0.2_h32cf16fc1se5ba_re_lstm_h256_0 \ No newline at end of file diff --git a/src/train_maze_s5.sh b/src/train_maze_s5.sh new file mode 100755 index 0000000..0d96b18 --- /dev/null +++ b/src/train_maze_s5.sh @@ -0,0 +1,76 @@ +DEFAULTVALUE=4 +device="${1:-$DEFAULTVALUE}" +CUDA_VISIBLE_DEVICES=${device} XLA_PYTHON_CLIENT_MEM_FRACTION=.40 LD_LIBRARY_PATH="" nice -n 5 python3 -m minimax.train \ +--wandb_mode=online \ +--wandb_project=overcooked-minimax-jax \ +--wandb_entity=${WANDB_ENTITY} \ +--seed=1 \ +--agent_rl_algo=ppo \ +--n_total_updates=30000 \ +--train_runner=plr \ +--n_devices=1 \ +--student_model_name=default_student_cnn \ +--env_name=Maze \ +--verbose=False \ +--log_dir=~/logs/minimax \ +--log_interval=10 \ +--from_last_checkpoint=True \ +--checkpoint_interval=1000 \ +--archive_interval=0 \ +--archive_init_checkpoint=False \ +--test_interval=100 \ +--n_students=1 \ +--n_parallel=32 \ +--n_eval=1 \ +--n_rollout_steps=256 \ +--lr=3e-05 \ +--lr_anneal_steps=0 \ +--max_grad_norm=0.5 \ +--adam_eps=1e-05 \ +--track_env_metrics=True \ +--discount=0.999 \ +--n_unroll_rollout=10 \ +--render=False \ +--ued_score=max_mc \ +--plr_replay_prob=0.5 \ +--plr_buffer_size=4000 \ +--plr_staleness_coef=0.3 \ +--plr_temp=0.3 \ +--plr_use_score_ranks=True \ +--plr_min_fill_ratio=0.5 \ +--plr_use_robust_plr=True \ +--plr_use_parallel_eval=False \ +--plr_force_unique=True \ +--student_gae_lambda=0.98 \ +--student_entropy_coef=0.001 \ +--student_value_loss_coef=0.5 \ +--student_n_unroll_update=5 \ +--student_ppo_n_epochs=5 \ +--student_ppo_n_minibatches=1 \ +--student_ppo_clip_eps=0.2 \ +--student_ppo_clip_value_loss=True \ +--student_recurrent_arch=s5 \ +--student_recurrent_hidden_dim=256 \ +--student_hidden_dim=32 \ +--student_n_hidden_layers=1 \ +--student_n_conv_filters=16 \ +--student_n_scalar_embeddings=4 \ +--student_scalar_embed_dim=5 \ +--student_s5_n_blocks=2 \ +--student_s5_n_layers=2 \ +--student_s5_layernorm_pos=pre \ +--student_s5_activation=half_glu1 \ +--maze_height=13 \ +--maze_width=13 \ +--maze_n_walls=60 \ +--maze_replace_wall_pos=True \ +--maze_sample_n_walls=False \ +--maze_see_agent=False \ +--maze_normalize_obs=True \ +--maze_obs_agent_pos=False \ +--maze_max_episode_steps=250 \ +--test_n_episodes=10 \ +--test_env_names=Maze-SixteenRooms,Maze-Labyrinth,Maze-StandardMaze \ +--maze_test_see_agent=False \ +--maze_test_normalize_obs=True \ +--xpid=plr-maze13x13w60na_f-rf_p0.5b4000t0.3s0.3m0.5r_r1s_32p_1e_256t_ae1e-05_smm-ppo_lr3e-05g0.999cv0.5ce0.001e5mb1l0.98_pc0.2_h32cf16fc1se5ba_re_lpr_ahg1_s5_h256nb2nl2_0