Commit 04c4625cfe ("Uploaded")
11 changed files with 1330 additions and 0 deletions
README.md (new file, 95 lines)
<div align="center">

<h1> SummAct: Uncovering User Intentions Through Interactive Behaviour Summarisation </h1>

**[Guanhua Zhang][4], [Mohamed Ahmed][3], [Zhiming Hu][5], [Andreas Bulling][6]** <br>

**ACM CHI 2025**, Yokohama, Japan <br>

**[[Project][2]]** **[[Paper][7]]** </div>

----------------

# Directory Structure

```
SummAct
│   README.md
│   environment.yml
│
└───preprocess
│       convert_dataset.py
│       create_steps.py
│
└───hf_bmt
│       hf_2_bmtrain.py
│       hf_2_bmtrain.sh
│       bmt_hf.py
│
└───train
│       train.py
│       train.sh
│
└───inference
│       inference.py
│       inference.sh
```

# Setup

We recommend setting up a virtual environment using Anaconda. <br>

1. Create a conda environment and install the dependencies

```shell
conda env create --name summact --file=environment.yml
conda activate summact
```

2. Since `model_center==1.0.3` is needed but is not yet available on PyPI, build it from [source](https://github.com/OpenBMB/ModelCenter)

```shell
git clone https://github.com/OpenBMB/ModelCenter.git
cd ModelCenter
pip install -r requirements.txt
python3 setup.py install
```

3. Clone our repository to download our code and a pretrained model

```shell
git clone this_repo.git
```

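As an optional sanity check (this step is not part of the original instructions), you can verify that the two core training dependencies resolve inside the activated environment:

```shell
# Both imports should succeed inside the `summact` environment.
python3 -c "import bmtrain, model_center; print('ok')"
```
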
# Preprocessing

1. Convert actions from symbolic formats into natural language by running `preprocess/convert_dataset.py` (a minimal sketch of the conversion rule is shown below this list). Adapt it to your local dataset paths.
2. Prompt the pretrained LLM with examples to generate sub-intentions using `preprocess/create_steps.py`. Adapt it to your local prompt txt path.

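For orientation, the sketch below mirrors the CLICK branch of the conversion rule implemented in `preprocess/convert_dataset.py`; the element name and text are made-up examples, not values taken from the dataset:

```python
# Illustrative sketch only; the full rule (TYPE / SELECT / CLICK) lives in preprocess/convert_dataset.py.
def describe_click(ui_element_name, ui_element_text):
    if ui_element_text != "":
        return f'Click the {ui_element_name} element with text "{ui_element_text}" on it'
    return f'Click the {ui_element_name} element'

# Hypothetical Mind2Web-style action representation: "[button] Search -> CLICK"
print(describe_click("button", "Search"))
# -> Click the button element with text "Search" on it
```
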
# Fine-tuning

1. After downloading the model from Hugging Face, convert it into `model_center` weights using the script `hf_bmt/hf_2_bmtrain.sh`. Adapt it to your local paths for the downloaded model and the desired output location.
2. Run `train/train.sh`, which will call `train/train.py` to fine-tune the model for interactive behaviour summarisation. Make sure your machine has GPUs.

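In practice these two steps amount to running the provided scripts; a minimal sketch of the flow is shown below, assuming you have already edited the path variables (`IN_PATH`, `OUT_PATH`, `PROJECT_PATH`) and GPU settings at the top of each script:

```shell
# Step 1: convert the Hugging Face checkpoint into model_center/BMTrain weights
cd hf_bmt && bash hf_2_bmtrain.sh && cd ..

# Step 2: launch multi-GPU fine-tuning via torchrun
cd train && bash train.sh
```
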
# Inference

Run `inference/inference.sh`, which will call `inference/inference.py` to convert the fine-tuned model back to HF format, and then calculate metrics to evaluate the summarisation quality.

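For reference, the evaluation compares each generated summary against the ground-truth task description using SacreBLEU, ROUGE, METEOR, and sentence-embedding cosine similarity. The sketch below shows that computation on a made-up prediction/reference pair (the embedding model name is the public Hugging Face ID of the same MiniLM model the scripts load from a local path):

```python
# Minimal sketch of the metrics used in inference/inference.py (toy strings, not real model outputs).
import evaluate
from sentence_transformers import SentenceTransformer, util

prediction = "Book a one-way flight from Berlin to Rome for next Friday."
reference = "Book a one-way flight from Berlin to Rome departing next Friday."

sacrebleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")

print(sacrebleu.compute(predictions=[prediction], references=[[reference]])["score"])
print(rouge.compute(predictions=[prediction], references=[[reference]])["rougeL"])
print(meteor.compute(predictions=[prediction], references=[[reference]])["meteor"])

# Cosine similarity between sentence embeddings, as in the paper's evaluation.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb_pred = embedder.encode(prediction.lower(), convert_to_tensor=True)
emb_ref = embedder.encode(reference.lower(), convert_to_tensor=True)
print(util.cos_sim(emb_ref, emb_pred).item())
```
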
# Citation

If you find our code useful or use it in your own projects, please cite our paper:

```
@inproceedings{zhang25_chi,
  title     = {SummAct: Uncovering User Intentions Through Interactive Behaviour Summarisation},
  author    = {Zhang, Guanhua and Ahmed, Mohamed and Hu, Zhiming and Bulling, Andreas},
  year      = {2025},
  pages     = {1--17},
  booktitle = {Proc. ACM SIGCHI Conference on Human Factors in Computing Systems (CHI)},
  doi       = {10.1145/3706598.3713190}
}
```

# Acknowledgements

Our work relied on the codebases of [Mind2Web][1], [ScreenAgent][8] and [Tell Me More!][9]. Thanks to the authors for sharing their code.

[1]: https://osu-nlp-group.github.io/Mind2Web/
[2]: https://collaborative-ai.org/publications/zhang25_chi/
[3]: https://www.linkedin.com/in/mohamed-adel-naguib/
[4]: https://scholar.google.com/citations?user=NqkK0GwAAAAJ&hl=en
[5]: https://scholar.google.com/citations?hl=en&user=OLB_xBEAAAAJ
[6]: https://www.collaborative-ai.org/people/bulling/
[7]: https://collaborative-ai.org/publications/zhang25_chi.pdf
[8]: https://github.com/niuzaisheng/ScreenAgent
[9]: https://github.com/OpenBMB/Tell_Me_More
environment.yml (new file, 184 lines)
name: Mistral
channels:
  - defaults
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - asttokens=2.4.1=pyhd8ed1ab_0
  - bzip2=1.0.8=hd590300_5
  - ca-certificates=2024.2.2=hbcca054_0
  - comm=0.2.2=pyhd8ed1ab_0
  - debugpy=1.8.1=py311hb755f60_0
  - decorator=5.1.1=pyhd8ed1ab_0
  - exceptiongroup=1.2.0=pyhd8ed1ab_2
  - executing=2.0.1=pyhd8ed1ab_0
  - importlib-metadata=7.1.0=pyha770c72_0
  - importlib_metadata=7.1.0=hd8ed1ab_0
  - ipykernel=6.29.3=pyhd33586a_0
  - ipython=8.24.0=pyh707e725_0
  - jedi=0.19.1=pyhd8ed1ab_0
  - jupyter_client=8.6.1=pyhd8ed1ab_0
  - jupyter_core=5.7.2=py311h38be061_0
  - keyutils=1.6.1=h166bdaf_0
  - krb5=1.21.2=h659d440_0
  - ld_impl_linux-64=2.40=h41732ed_0
  - libedit=3.1.20191231=he28a2e2_2
  - libexpat=2.6.2=h59595ed_0
  - libffi=3.4.2=h7f98852_5
  - libgcc-ng=13.2.0=h807b86a_5
  - libgomp=13.2.0=h807b86a_5
  - libnsl=2.0.1=hd590300_0
  - libsodium=1.0.18=h36c2ea0_1
  - libsqlite=3.45.2=h2797004_0
  - libstdcxx-ng=13.2.0=hc0a3c3a_7
  - libuuid=2.38.1=h0b41bf4_0
  - libxcrypt=4.4.36=hd590300_1
  - libzlib=1.2.13=hd590300_5
  - matplotlib-inline=0.1.7=pyhd8ed1ab_0
  - ncurses=6.4.20240210=h59595ed_0
  - nest-asyncio=1.6.0=pyhd8ed1ab_0
  - openssl=3.3.0=hd590300_0
  - packaging=24.0=pyhd8ed1ab_0
  - parso=0.8.4=pyhd8ed1ab_0
  - pexpect=4.9.0=pyhd8ed1ab_0
  - pickleshare=0.7.5=py_1003
  - pip=24.0=pyhd8ed1ab_0
  - platformdirs=4.2.1=pyhd8ed1ab_0
  - prompt-toolkit=3.0.42=pyha770c72_0
  - psutil=5.9.8=py311h459d7ec_0
  - ptyprocess=0.7.0=pyhd3deb0d_0
  - pure_eval=0.2.2=pyhd8ed1ab_0
  - pygments=2.18.0=pyhd8ed1ab_0
  - python=3.11.8=hab00c5b_0_cpython
  - python_abi=3.11=4_cp311
  - pyzmq=26.0.3=py311h08a0b41_0
  - readline=8.2=h8228510_1
  - setuptools=69.5.1=pyhd8ed1ab_0
  - six=1.16.0=pyh6c4a22f_0
  - stack_data=0.6.2=pyhd8ed1ab_0
  - tk=8.6.13=noxft_h4845f30_101
  - tornado=6.4=py311h459d7ec_0
  - traitlets=5.14.3=pyhd8ed1ab_0
  - typing_extensions=4.11.0=pyha770c72_0
  - wcwidth=0.2.13=pyhd8ed1ab_0
  - wheel=0.43.0=pyhd8ed1ab_1
  - xz=5.2.6=h166bdaf_0
  - zeromq=4.3.5=h75354e8_4
  - zipp=3.17.0=pyhd8ed1ab_0
  - pip:
    - absl-py==2.1.0
    - accelerate==0.29.2
    - aiohttp==3.9.4
    - aiosignal==1.3.1
    - annotated-types==0.7.0
    - antlr4-python3-runtime==4.9.3
    - anyio==4.4.0
    - appdirs==1.4.4
    - attrs==23.2.0
    - beautifulsoup4==4.12.3
    - bmtrain==1.0.0
    - bs4==0.0.2
    - certifi==2024.2.2
    - charset-normalizer==3.3.2
    - click==8.1.7
    - colorama==0.4.6
    - cprint==1.2.2
    - cython==0.29.37
    - datasets==2.18.0
    - dill==0.3.8
    - distro==1.9.0
    - docker-pycreds==0.4.0
    - evaluate==0.4.1
    - filelock==3.13.4
    - frozenlist==1.4.1
    - fsspec==2024.2.0
    - gitdb==4.0.11
    - gitpython==3.1.43
    - grpcio==1.62.1
    - h11==0.14.0
    - hdbscan==0.8.37
    - httpcore==1.0.5
    - httpx==0.27.0
    - huggingface-hub==0.22.2
    - hydra-core==1.3.2
    - idna==3.7
    - jieba==0.42.1
    - jinja2==3.1.3
    - joblib==1.4.0
    - keybert==0.8.5
    - levenshtein==0.25.1
    - lxml==5.2.1
    - markdown==3.6
    - markdown-it-py==3.0.0
    - markupsafe==2.1.5
    - mdurl==0.1.2
    - mpmath==1.3.0
    - multidict==6.0.5
    - multiprocess==0.70.16
    - networkx==3.3
    - nltk==3.8.1
    - numpy==1.26.4
    - nvidia-cublas-cu12==12.1.3.1
    - nvidia-cuda-cupti-cu12==12.1.105
    - nvidia-cuda-nvrtc-cu12==12.1.105
    - nvidia-cuda-runtime-cu12==12.1.105
    - nvidia-cudnn-cu12==8.9.2.26
    - nvidia-cufft-cu12==11.0.2.54
    - nvidia-curand-cu12==10.3.2.106
    - nvidia-cusolver-cu12==11.4.5.107
    - nvidia-cusparse-cu12==12.1.0.106
    - nvidia-nccl-cu11==2.21.5
    - nvidia-nccl-cu12==2.19.3
    - nvidia-nvjitlink-cu12==12.4.127
    - nvidia-nvtx-cu12==12.1.105
    - omegaconf==2.3.0
    - openai==1.36.0
    - pandas==2.2.2
    - pdb-tools==2.5.0
    - pillow==10.3.0
    - portalocker==2.8.2
    - protobuf==4.25.3
    - pyarrow==15.0.2
    - pyarrow-hotfix==0.6
    - pydantic==2.8.2
    - pydantic-core==2.20.1
    - python-dateutil==2.9.0.post0
    - pytz==2024.1
    - pyyaml==6.0.1
    - rapidfuzz==3.9.6
    - regex==2023.12.25
    - requests==2.31.0
    - responses==0.18.0
    - rich==13.7.1
    - rouge-score==0.1.2
    - sacrebleu==2.4.2
    - safetensors==0.4.3
    - scikit-learn==1.4.2
    - scipy==1.13.0
    - sentence-transformers==2.7.0
    - sentencepiece==0.2.0
    - sentry-sdk==1.45.0
    - setproctitle==1.3.3
    - smmap==5.0.1
    - sniffio==1.3.1
    - soupsieve==2.5
    - sympy==1.12
    - tabulate==0.9.0
    - tensorboard==2.16.2
    - tensorboard-data-server==0.7.2
    - textblob==0.18.0.post0
    - threadpoolctl==3.4.0
    - tokenizers==0.15.2
    - torch==2.2.2
    - torchvision==0.17.2
    - tqdm==4.66.2
    - transformers==4.39.3
    - triton==2.2.0
    - tzdata==2024.1
    - urllib3==2.2.1
    - wandb==0.16.6
    - werkzeug==3.0.2
    - xxhash==3.4.1
    - yarl==1.9.4
prefix: /opt/anaconda3/envs/Mistral
hf_bmt/bmt_hf.py (new file, 92 lines)
import os, pdb
import json
import torch
import sys
import shutil
import argparse
from collections import OrderedDict
from transformers import AutoConfig, AutoModelForCausalLM


def transform_to_hf(bmt_model, model_size):
    model_hf = OrderedDict()

    # Embedding, final layer norm and LM head; key names differ depending on
    # whether the checkpoint was saved with or without the "LLM." prefix.
    if 'input_embedding.weight' in bmt_model.keys():
        model_hf['model.embed_tokens.weight'] = bmt_model["input_embedding.weight"].contiguous().float()
        model_hf['model.norm.weight'] = bmt_model["encoder.output_layernorm.weight"].contiguous().float()
        try:
            model_hf['lm_head.weight'] = bmt_model['output_projection.weight'].contiguous().float()
        except KeyError:
            model_hf['lm_head.weight'] = bmt_model["input_embedding.weight"].contiguous().float()
    else:
        model_hf['model.embed_tokens.weight'] = bmt_model["LLM.input_embedding.weight"].contiguous().float()
        model_hf['model.norm.weight'] = bmt_model["LLM.encoder.output_layernorm.weight"].contiguous().float()
        try:
            model_hf['lm_head.weight'] = bmt_model['LLM.output_projection.weight'].contiguous().float()
        except KeyError:
            model_hf['lm_head.weight'] = bmt_model["LLM.input_embedding.weight"].contiguous().float()

    if model_size == "7b":
        layernum = 32
    elif model_size == "13b" or model_size == "13b-2":
        layernum = 40
    elif model_size == "65b":
        layernum = 80

    # Copy the per-layer attention and MLP weights.
    for lnum in range(layernum):
        hf_pfx = f"model.layers.{lnum}"
        if 'input_embedding.weight' in bmt_model.keys():
            bmt_pfx = f"encoder.layers.{lnum}"
        else:
            bmt_pfx = f"LLM.encoder.layers.{lnum}"

        model_hf[f"{hf_pfx}.input_layernorm.weight"] = bmt_model[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"].contiguous().float()

        model_hf[f"{hf_pfx}.self_attn.q_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_q.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.self_attn.k_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_k.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.self_attn.v_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_v.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.self_attn.o_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"].contiguous().float()

        model_hf[f"{hf_pfx}.post_attention_layernorm.weight"] = bmt_model[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"].contiguous().float()

        model_hf[f"{hf_pfx}.mlp.gate_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"].contiguous().float()
        model_hf[f"{hf_pfx}.mlp.up_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"].contiguous().float()

        model_hf[f"{hf_pfx}.mlp.down_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_out.weight"].contiguous().float()

    for key in model_hf:
        model_hf[key] = model_hf[key].bfloat16()
    return model_hf


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--in_path", type=str)
    parser.add_argument("--output_path", type=str)
    parser.add_argument("--original_mistral_path", type=str)

    args = parser.parse_args()
    os.makedirs(args.output_path, exist_ok=True)
    print("transforming " + args.in_path)

    model_size = "7b"

    ckpt = [name for name in os.listdir(args.in_path) if name.endswith(".pt")]
    bmt_model = torch.load(os.path.join(args.in_path, ckpt[0]))

    hf_state_dict = transform_to_hf(bmt_model, model_size)
    print(f"start saving to {args.output_path}")

    model_config = AutoConfig.from_pretrained(args.original_mistral_path)
    model = AutoModelForCausalLM.from_config(model_config)
    model.load_state_dict(hf_state_dict)

    for param in model.parameters():
        param.data = param.data.to(torch.bfloat16)

    model.save_pretrained(args.output_path, safe_serialization=False)
    for file_name in ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "tokenizer.json"]:
        if os.path.exists(os.path.join(args.in_path, file_name)):
            shutil.copy(os.path.join(args.in_path, file_name), os.path.join(args.output_path, file_name))
    print("saved huggingface checkpoint")
hf_bmt/hf_2_bmtrain.py (new file, 108 lines)
from transformers import LlamaConfig
from transformers import AutoModelForCausalLM
import torch, os
import json
from collections import OrderedDict
import shutil, pdb

import argparse


def initialize():
    # get arguments
    parser = argparse.ArgumentParser("")
    # Output directory for the bmtrain weights.
    parser.add_argument("--out_path", type=str, default=f"/Mistral-{ver}-bmtrain")
    # Path where you downloaded the mistral-7b Hugging Face weights.
    parser.add_argument('--in_path', type=str, default=f"/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24")
    args = parser.parse_args()
    return args


ver = "7b"
# Change these two
# Output directory for the bmtrain weights.
# outpath = f"/Mistral-{ver}-bmtrain"
# Path where you downloaded the mistral-7b Hugging Face weights.
# inpath = f"/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"


def convert_weights(args):
    hf_config = LlamaConfig.from_pretrained(args.in_path)
    config = {
        'dim_model': hf_config.hidden_size,
        'dim_ff': hf_config.intermediate_size,
        'num_layers': hf_config.num_hidden_layers,
        'num_heads': hf_config.num_attention_heads,
        'num_heads_kv': hf_config.num_key_value_heads,
        'dim_head': hf_config.hidden_size // hf_config.num_attention_heads,
        'norm_eps': hf_config.rms_norm_eps,
    }
    os.makedirs(args.out_path, exist_ok=True)

    with open(os.path.join(args.out_path, "config.json"), 'w') as f:
        json.dump(config, f)

    layernum = config['num_layers']

    model_hf = OrderedDict()
    ckpt_num = None
    if 'v0.1' in args.in_path:
        prefix = "pytorch_model-"
        endtext = ".bin"
    else:
        prefix = "model-"
        endtext = ".safetensors"
    # Determine the number of checkpoint shards from the file names.
    for name in os.listdir(args.in_path):
        if name.startswith(prefix) and name.endswith(endtext):
            ckpt_num = int(name.split(endtext)[0].split('-')[-1])
    # Load and merge all shards into a single state dict.
    for i in range(1, ckpt_num + 1):
        if 'v0.1' in args.in_path:
            part = torch.load(os.path.join(args.in_path, f"pytorch_model-{i:05d}-of-{ckpt_num:05d}.bin"))
        else:
            from safetensors import safe_open
            with safe_open(os.path.join(args.in_path, f"model-{i:05d}-of-{ckpt_num:05d}.safetensors"), framework="pt", device=0) as f:
                part = {}
                for k in f.keys():
                    part[k] = f.get_tensor(k)
        model_hf.update(part)

    out = OrderedDict()

    out["input_embedding.weight"] = model_hf['model.embed_tokens.weight'].contiguous()
    out["encoder.output_layernorm.weight"] = model_hf['model.norm.weight'].contiguous()
    out['output_projection.weight'] = model_hf['lm_head.weight'].contiguous()
    for lnum in range(layernum):
        hf_pfx = f"model.layers.{lnum}"
        bmt_pfx = f"encoder.layers.{lnum}"

        out[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = model_hf[f"{hf_pfx}.input_layernorm.weight"].contiguous()

        out[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = model_hf[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = model_hf[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = model_hf[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
        out[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = model_hf[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()

        out[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = model_hf[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()

        out[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = model_hf[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
        out[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = model_hf[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()

        out[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = model_hf[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()

    for key in out:
        out[key] = out[key].half()

    if not os.path.exists(args.out_path):
        os.makedirs(args.out_path)
    torch.save(out, os.path.join(args.out_path, "pytorch_model.pt"))

    for file_name in ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "tokenizer.json"]:
        if os.path.exists(os.path.join(args.in_path, file_name)):
            shutil.copy(os.path.join(args.in_path, file_name), os.path.join(args.out_path, file_name))

    print("BMT weights created successfully")


def main():
    args = initialize()
    convert_weights(args)


if __name__ == "__main__":
    main()
hf_bmt/hf_2_bmtrain.sh (new file, 13 lines)
IN_PATH="your-path-to-hf-model"
OUT_PATH="your-wanted-path-to-bm-model"

OPTS=""
OPTS+="--in_path ${IN_PATH} "
OPTS+="--out_path ${OUT_PATH}"

CMD="python3 hf_2_bmtrain.py ${OPTS}"

echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
eval ${CMD}
inference/inference.py (new file, 158 lines)
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
import os, pdb
import numpy as np
import json
from pathlib import Path
from tqdm import tqdm
from cprint import cprint
import evaluate
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
import logging
import torch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from sentence_transformers import SentenceTransformer, util


def initialize():
    parser = argparse.ArgumentParser("")
    parser.add_argument("--model_name_or_path", type=str, default='')
    parser.add_argument("--embedding_model_path", type=str, default="")
    parser.add_argument("--train_data_dir", type=str, default='')
    parser.add_argument("--test_data_dir", type=str, default='')
    parser.add_argument("--prompt_file", type=str, default=None, help="The file for loading the prompt")
    args = parser.parse_args()
    return args


def get_tokenizer(args):
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, device_map={"": 0})
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    return tokenizer


def get_model(args):
    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map={"": 0})
    return model


def setup_model_and_tokenizer(args):
    tokenizer = get_tokenizer(args)
    model = get_model(args)
    return tokenizer, model


def read_json_file(filename):
    with open(filename, 'r') as infile:
        data = json.load(infile)
    return data


def format_one_action(action):
    return f"- {action}\n"


def format_actions_list(actions):
    actions_str = ""
    for action in actions:
        actions_str += format_one_action(action)
    return actions_str


def preprocess_data(task, args):
    with open(args.prompt_file, 'r') as file:
        task_description = file.read().split('===')

    input_str = f"## Website:\n{task['website_en']}\n\n## Domain:\n{task['domain_en']}\n\n## Sub-domain:\n{task['subdomain_en']}\n\n## Actions (Each line is one action):\n{format_actions_list(task['task_subintention'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(task['steps'])}"
    query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
    summary_str = task['task_description']
    summary_str = summary_str[0].upper() + summary_str[1:] + "."
    test_prompt = f"User: {query_inputs}\nAgent:"
    return {"task": summary_str, "prompt": test_prompt}


def load_raw_dataset(data, args):
    tasks = []
    for d in tqdm(data):
        processed_task = preprocess_data(d, args)
        tasks.append(processed_task)
    return tasks


def main_loop(args, test_dataset, tokenizer, model, sacrebleu, rouge, meteor, embedding_model, mark):
    os.makedirs(args.model_name_or_path + "/results/", exist_ok=True)
    global_sacrebleu, global_rouge1, global_rouge2, global_rougeL, global_rougeLsum, global_meteor, global_cosine, global_distance = [], [], [], [], [], [], [], []
    for i, data in tqdm(enumerate(test_dataset)):
        save_task_response_filename = args.model_name_or_path + f"/results/{mark}_{i}_insert_mistral.json"
        if os.path.exists(save_task_response_filename):
            with open(save_task_response_filename, 'r') as f:
                save_dict = json.load(f)
        else:
            prompt = data["prompt"]
            task = data["task"]

            save_dict = {}
            model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
            generated_ids = model.generate(**model_inputs, max_new_tokens=1024, do_sample=False, top_p=0.95, repetition_penalty=1.2)
            pred = tokenizer.batch_decode(generated_ids)[0]
            response = pred.split("[SUMMARY]")[-1].replace('</s>', '').strip()

            rouge_calc = rouge.compute(predictions=[response], references=[[task]], use_aggregator=True)
            sacrebleu_calc = sacrebleu.compute(predictions=[response], references=[[task]])
            meteor_calc = meteor.compute(predictions=[response], references=[[task]])
            GT_Embedding = embedding_model.encode(task.lower(), convert_to_tensor=True)
            Prediction_Embedding = embedding_model.encode(response.lower(), convert_to_tensor=True)
            cosine_similarity = util.cos_sim(GT_Embedding, Prediction_Embedding).item()
            euclidean_disance = torch.sqrt(torch.sum(torch.pow(torch.subtract(GT_Embedding, Prediction_Embedding), 2))).item()
            save_dict["prompt"] = prompt
            save_dict["prediction"] = response
            save_dict["task"] = task
            save_dict["sacrebleu"] = sacrebleu_calc
            save_dict["rouge"] = rouge_calc
            save_dict["meteor"] = meteor_calc
            save_dict["cosine_similarity"] = cosine_similarity
            save_dict["euclidean_disance"] = euclidean_disance

            with open(save_task_response_filename, 'w') as f:
                json.dump(save_dict, f)

        global_sacrebleu.append(save_dict["sacrebleu"]["score"])
        global_rouge1.append(save_dict["rouge"]["rouge1"])
        global_rouge2.append(save_dict["rouge"]["rouge2"])
        global_rougeL.append(save_dict["rouge"]["rougeL"])
        global_rougeLsum.append(save_dict["rouge"]["rougeLsum"])
        global_meteor.append(save_dict["meteor"]["meteor"])
        global_cosine.append(save_dict["cosine_similarity"])
        global_distance.append(save_dict["euclidean_disance"])

    return global_sacrebleu, global_rouge1, global_rouge2, global_rougeL, global_rougeLsum, global_meteor, global_cosine, global_distance


def main(mark):
    args = initialize()
    assert 'Mind2Web' in args.test_data_dir
    tokenizer, model = setup_model_and_tokenizer(args)
    sacrebleu = evaluate.load('sacrebleu', module_type="metric")
    rouge = evaluate.load('rouge', module_type="metric")
    meteor = evaluate.load('meteor', module_type="metric")
    embedding_model = SentenceTransformer(args.embedding_model_path, device="cuda")

    test_folders_names = ["test_domain", "test_task", "test_website"]
    for name in test_folders_names:
        test_folder_path = Path(os.path.join(args.test_data_dir, name))
        global_sacrebleu, global_rouge1, global_rouge2, global_rougeL, global_rougeLsum, global_meteor, global_cosine, global_distance = [], [], [], [], [], [], [], []
        for json_file in test_folder_path.rglob('*_with_steps_insert_mistral.json'):
            with json_file.open('r') as f:
                data = json.load(f)
            raw_tasks = load_raw_dataset(data, args)
            sacrebleu_calc, rouge1_calc, rouge2_calc, rougeL_calc, rougeLsum_calc, meteor_calc, cosine_calc, distance_calc = main_loop(args, raw_tasks, tokenizer, model, sacrebleu, rouge, meteor, embedding_model, 'test_%s' % (name))

            global_sacrebleu.extend(sacrebleu_calc)
            global_rouge1.extend(rouge1_calc)
            global_rouge2.extend(rouge2_calc)
            global_rougeL.extend(rougeL_calc)
            global_rougeLsum.extend(rougeLsum_calc)
            global_meteor.extend(meteor_calc)
            global_cosine.extend(cosine_calc)
            global_distance.extend(distance_calc)

        print(mark, name)
        print("%.3f" % (np.mean(global_cosine)))
        print("%.3f" % (np.mean(global_sacrebleu) / 100.0))
        print("%.3f" % (np.mean(global_rougeL)))
        print("%.3f" % (np.mean(global_meteor)))


if __name__ == "__main__":
    main('test')
inference/inference.sh (new file, 30 lines)
PROJECT_PATH="your-project-path"
EMBEDDING_MODEL_PATH="${PROJECT_PATH}/sentence-transformer/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/e4ce9877abf3edfe10b0d82785e83bdcb973e22e"

OPTS=""
OPTS+=" --embedding_model_path ${EMBEDDING_MODEL_PATH}"
OPTS+=" --test_data_dir ${PROJECT_PATH}/data/Mind2Web/test"
OPTS+=" --train_data_dir ${PROJECT_PATH}/data/Mind2Web/train/train_with_steps_insert_mistral.json"
OPTS+=" --prompt_file ${PROJECT_PATH}/prompts/summarisation/summarisation_prompt.txt"

MODEL_NAME_OR_PATH_BMT="${PROJECT_PATH}/ckpts/experiment/epoch_14"
MODEL_NAME_OR_PATH_HF="${MODEL_NAME_OR_PATH_BMT}-hf"
MODEL_NAME_OR_PATH_ORIGINAL_MISTRAL="${PROJECT_PATH}/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"

# Convert the model to HF format
if [ ! -f "${MODEL_NAME_OR_PATH_HF}/config.json" ]; then
    CMD="python3 ${PROJECT_PATH}/hf_bmt/bmt_hf.py --in_path ${MODEL_NAME_OR_PATH_BMT} --output_path ${MODEL_NAME_OR_PATH_HF} --original_mistral_path ${MODEL_NAME_OR_PATH_ORIGINAL_MISTRAL}"
    echo "-------BMT -> HF CMD is------"
    echo "CMD: ${CMD}"
    echo "-------BMT -> HF CMD end------"
    eval ${CMD}
fi

OPTS+=" --model_name_or_path ${MODEL_NAME_OR_PATH_HF}"

CMD="python3 inference.py ${OPTS}"

echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
eval ${CMD}
preprocess/convert_dataset.py (new file, 135 lines)
import os, pdb
import re
import json
from enum import Enum
from tqdm import tqdm
from bs4 import BeautifulSoup


def read_json_file(filename):
    with open(filename, 'r') as infile:
        data = json.load(infile)
    return data


def convert_string(string_or_list):
    # Add escaping symbols to English quotes in string
    if isinstance(string_or_list, str):
        return string_or_list.replace('"', '\\"')
    elif isinstance(string_or_list, list):
        return [convert_string(s) for s in string_or_list]


def is_visible(element):
    bounding_box = element.get('bounding_box_rect')
    return bounding_box != "-1,-1,-1,-1"


def clean_text(text):
    cleaned_text = text.strip()
    cleaned_text = cleaned_text.replace('\n', ' ').replace('\t', ' ')
    cleaned_text = re.sub(' +', ' ', cleaned_text)
    return cleaned_text


def find_semantic_info(element):
    element_text = clean_text(element.get_text(strip=True))
    if element_text:
        return element_text

    label = element.find_previous(lambda x: x.name == 'label' and is_visible(x))
    if label:
        label_text = clean_text(label.get_text(strip=True))
        if label_text:
            return label_text
    return None


def action_discription(ui_element_name, ui_element_text, operation_type, value):
    ret_en = ""
    if operation_type == "TYPE":
        if ui_element_text != "":
            ret_en += f'Type text "{value}" into {ui_element_name} with text "{ui_element_text}" on it'
        else:
            ret_en += f'Type text "{value}" into {ui_element_name}'
    elif operation_type == "SELECT":
        if ui_element_text != "":
            ret_en += f'Select "{value}" from {ui_element_name} with text "{ui_element_text}" on it'
        else:
            ret_en += f'Select "{value}" from {ui_element_name}.'
    elif operation_type == "CLICK":
        if ui_element_text != "":
            ret_en += f'Click the {ui_element_name} element with text "{ui_element_text}" on it'
        else:
            ret_en += f'Click the {ui_element_name} element'
    return ret_en


def process_one_task(task):
    base_info = {
        "website_en": task["website"],
        "domain_en": task["domain"],
        "subdomain_en": task["subdomain"],
        "annotation_id": task["annotation_id"],
        "task_description": task["confirmed_task"],
        "action_reprs": task["action_reprs"]
    }
    action_descriptions_en = []
    for action_index, action in enumerate(task["actions"]):
        action_repr = task["action_reprs"][action_index]
        ui_element, _ = action_repr.split(" -> ")
        assert ui_element.count("] ") == 1
        ui_element_name, ui_element_text = ui_element.split("] ")
        ui_element_name = ui_element_name[1:]
        ui_element_text = ui_element_text.strip()

        if ui_element_text == "":
            raw_html = action["raw_html"]
            soup2 = BeautifulSoup(raw_html, 'html.parser')
            selected_element2 = soup2.find(attrs={"data_pw_testid_buckeye": action["action_uid"]})

            ui_element_text = find_semantic_info(selected_element2)
            if ui_element_text is not None:
                ui_element_text = clean_text(ui_element_text)
                task["action_reprs"][action_index] = f"[{ui_element_name}] {ui_element_text} -> {task['action_reprs'][action_index].split(' -> ')[1]}"
            else:
                print(f'Warning: {task["annotation_id"]}, can not find semantic info for {action["action_uid"]}')

        action_description_en = action_discription(ui_element_name, ui_element_text, action["operation"]["op"], action["operation"]["value"])
        action_descriptions_en.append(action_description_en)

    base_info["task_subintention"] = action_descriptions_en
    return base_info


if __name__ == "__main__":
    for foldername in ['train', 'test_domain', 'test_website', 'test_task']:
        SAVE_PATH = f"your-path-to-data/{foldername}"

        for idx in range(100):
            savejsonfilename = os.path.join(SAVE_PATH, f'{foldername}_{idx}_with_actions_description_insert.json')
            if os.path.exists(savejsonfilename):
                continue
            else:
                jsonfilename = f"{SAVE_PATH}/{foldername}_{idx}.json"
                if not os.path.exists(jsonfilename):
                    break
                dataset = read_json_file(jsonfilename)

                Mind2Web_with_subintentions = []
                for task in tqdm(dataset):
                    base_info = process_one_task(task)
                    Mind2Web_with_subintentions.append(base_info)
                assert len(Mind2Web_with_subintentions) == len(dataset)

                if 'test' in foldername:
                    with open(os.path.join(SAVE_PATH, f'{foldername}_{idx}_with_actions_description.json'), 'r') as json_file:
                        Mind2Web_with_subintentions_saved = json.load(json_file)

                    for i in range(len(Mind2Web_with_subintentions)):
                        if i >= len(Mind2Web_with_subintentions_saved):
                            break
                        if Mind2Web_with_subintentions[i] != Mind2Web_with_subintentions_saved[i]:
                            for key in Mind2Web_with_subintentions[i].keys():
                                if Mind2Web_with_subintentions[i][key] != Mind2Web_with_subintentions_saved[i][key]:
                                    found = False
                                    for j in range(len(Mind2Web_with_subintentions_saved)):
                                        if Mind2Web_with_subintentions[i][key] == Mind2Web_with_subintentions_saved[j][key]:
                                            found = True
                                            break
                                    if not found:
                                        print(found, i, j, jsonfilename)
                with open(savejsonfilename, 'w') as json_file:
                    json.dump(Mind2Web_with_subintentions, json_file)
preprocess/create_steps.py (new file, 97 lines)
from tqdm import tqdm
import json
import os
from transformers import AutoTokenizer, AutoModelForCausalLM


def get_tokenizer(model_name_or_path):
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, device_map={"": 0})
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    return tokenizer


def get_model(model_name_or_path):
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map={"": 0})
    return model


def read_json_file(filename):
    with open(filename, 'r') as infile:
        data = json.load(infile)
    return data


if __name__ == "__main__":
    model_name_or_path = "Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"
    tokenizer = get_tokenizer(model_name_or_path)
    model = get_model(model_name_or_path)

    # load prompts
    with open("your-path-to-data/train_prompt.txt", "r") as f:
        train_prompt = f.read()
    with open("your-path-to-data/test_prompt.txt", "r") as f:
        test_prompt = f.read()

    for foldername in ['train', 'test_domain', 'test_website', 'test_task']:
        SAVE_PATH = f"your-path-to-data/{foldername}"

        for idx in range(100):
            savejsonfilename = f"{SAVE_PATH}/{foldername}_{idx}_with_steps_insert_mistral.json"
            jsonfilename = f"{SAVE_PATH}/{foldername}_{idx}_with_actions_description_insert.json"
            if not os.path.exists(jsonfilename):
                break

            data = read_json_file(jsonfilename)
            if os.path.exists(savejsonfilename):
                data = read_json_file(savejsonfilename)
            actions_steps = []
            for i in tqdm(range(len(data)), desc="Steps_Creation"):
                if "train" in foldername:  # include task
                    message = f"""Website: {data[i]["website_en"]}
Domain: {data[i]["domain_en"]}
Sub-domain: {data[i]["subdomain_en"]}
Task: {data[i]["task_description"]}
Actions: {data[i]["task_subintention"]}\n
# OUTPUT #
"""
                    prompt = train_prompt
                else:  # exclude task
                    message = f"""Website: {data[i]["website_en"]}
Domain: {data[i]["domain_en"]}
Sub-domain: {data[i]["subdomain_en"]}
Actions: {data[i]["task_subintention"]}\n
# OUTPUT #
"""
                    prompt = test_prompt

                messages = [
                    {"role": "system", "content": prompt},
                    {"role": "user", "content": message}
                ]
                messages = 'System: ' + prompt + 'User: ' + message

                model_inputs = tokenizer(messages, return_tensors="pt").to("cuda")
                assert len(model_inputs['input_ids']) <= 4096
                generated_ids = model.generate(**model_inputs, max_new_tokens=1024, do_sample=False, top_p=0.95, repetition_penalty=1.2)
                json_object = tokenizer.batch_decode(generated_ids)[0]
                answer = json_object.split('Sub-intentions: [')[-1].split('\n')
                final_answer = []
                for a in answer:
                    a = a.strip()
                    if '</s>' in a:
                        a = a.split('</s>')[0]
                    if len(a) == 0:
                        continue
                    while a[0] == '"':
                        a = a[1:]
                        if len(a) == 0:
                            break
                    if len(a) == 0:
                        continue
                    while a[-1] in ['"', ',', ']', ]:
                        a = a[:-1]
                        if len(a) == 0:
                            break
                    if len(a) == 0:
                        continue
                    final_answer.append(a)
                data[i]['steps'] = final_answer
                with open(savejsonfilename, 'w') as json_file:
                    json.dump(data, json_file)
train/train.py (new file, 374 lines)
import argparse
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import bmtrain as bmt
from functools import partial
import time
import os, pdb, shutil
import random
import json
from model_center.model import Llama
from model_center.tokenizer import LlamaTokenizer
from dataset_wrapper import PromptIterableDataset, collator
import wandb
import csv
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
import logging
import numpy as np
import math
from sentence_transformers import SentenceTransformer, util
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def set_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def get_tokenizer(args):
    tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'
    return tokenizer


def get_model(args):
    model = Llama.from_pretrained(args.model_name_or_path)
    if args.load_ckpt is not None:
        logger.info(f"loading model from {args.load_ckpt}")
        bmt.load(model, os.path.join(args.load_ckpt, "pytorch_model.pt"))
    return model


def get_optimizer(args, model):
    optimizer = bmt.optim.AdamOffloadOptimizer(
        model.parameters(),
        weight_decay=args.weight_decay,
        eps=1e-5,
        betas=(0.9, 0.95)
    )
    if args.load_ckpt is not None:
        file_name = os.path.join(args.load_ckpt, "optim.rank-{}.opt".format(bmt.rank()))
        logger.info(file_name)
        if os.path.exists(file_name):
            logger.info("start to load gradient ckpt {}".format(file_name))
            states = torch.load(file_name)
            optimizer.load_state_dict(states)
    return optimizer


def get_learning_rate_scheduler(args, optimizer):
    if args.lr_decay_iters is None:
        args.lr_decay_iters = args.train_iters
    if args.lr_decay_style == "linear":
        lr_scheduler = bmt.lr_scheduler.Linear(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    elif args.lr_decay_style == "cosine":
        bmt.print_rank("use cosine")

        class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
            def get_lr_warmup(self, num_iter) -> float:
                return self.start_lr * num_iter / self.warmup_iter

            def get_lr_decay(self, num_iter) -> float:
                progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
                return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))

        lr_scheduler = Cosine(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )

    elif args.lr_decay_style == "noam":
        logger.info("use noam")
        lr_scheduler = bmt.lr_scheduler.Noam(
            optimizer,
            start_lr=args.lr,
            warmup_iter=int(args.warmup_iters),
            end_iter=args.lr_decay_iters,
            num_iter=args.start_step,
        )
    else:
        raise NotImplementedError
    return lr_scheduler


def setup_model_and_optimizer(args):
    # get the tokenizer
    tokenizer = get_tokenizer(args)
    # get the model
    model = get_model(args)
    bmt.synchronize()
    # get the optimizer and lr_scheduler
    optimizer = get_optimizer(args, model)
    lr_scheduler = get_learning_rate_scheduler(args, optimizer)
    bmt.synchronize()
    return tokenizer, model, optimizer, lr_scheduler


def initialize():
    parser = argparse.ArgumentParser("")
    # model training arguments
    parser.add_argument("--lr", type=float, default=1e-5)
    parser.add_argument("--model_name_or_path")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_seq_length", default=2048, type=int)
    parser.add_argument("--batch_size_per_device", default=2, type=int)
    parser.add_argument("--logging_step", default=100, type=int)
    parser.add_argument("--save_step", default=50000, type=int)
    parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
    parser.add_argument("--wandb", default=True, action="store_true")
    parser.add_argument("--with_eval", action="store_true")
    parser.add_argument("--clip_grad", type=float, default=1.0, help="gradient clipping")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay rate")
    parser.add_argument("--loss_scale", type=float, default=6553600, help="loss scale")
    parser.add_argument("--train_iters", type=int, default=2000000)

    # loss parameters
    parser.add_argument("--action_weight", type=float, help="weight of the tokens that match the action")
    parser.add_argument("--embedding_model_path", type=str, help="The path to the sentence embedding model")

    # data parameters
    parser.add_argument('--data_setting', type=str, help='MTSD or MTMD', default="MTMD")
    parser.add_argument('--data_dir', type=str, help='The directory for saving the dataset')
    parser.add_argument('--max_train_samples', type=int, help='The maximum number of training samples')

    parser.add_argument('--cache_dir', type=str, help='The directory for cache')
    parser.add_argument("--save_dir", type=str, default="")

    parser.add_argument("--save_limit", type=int, default=None, help="ckpt saved limit number")

    parser.add_argument("--warmup_iters", type=int, default=1000)
    parser.add_argument(
        "--lr_decay_style",
        type=str,
        default="cosine",
        choices=["constant", "linear", "cosine", "exponential", "noam"],
        help="learning rate decay function",
    )
    parser.add_argument("--lr_decay_iters", type=int, default=None, help="lr decay steps")
    parser.add_argument("--start_step", type=int, default=0, help="step to start or continue training")
    parser.add_argument("--load_ckpt", type=str, default=None, help="resumed ckpt")
    parser.add_argument("--save_processed_data", action='store_true', help="whether or not to save the processed data")
    parser.add_argument("--prompt_file", type=str, default=None, help="The file for loading the prompt")
    args = parser.parse_args()
    # init bmt
    bmt.init_distributed(seed=args.seed)
    set_seed(args.seed)
    # wandb
    if args.wandb and bmt.rank() == 0:
        wandb.init(project='Mistral-Interact', config=args, name=args.save_dir.split('Mistral-7b/')[1][:-1], save_code=True, settings=wandb.Settings(code_dir="."))
    return args


def format_one_action(action):
    return f"- {action}\n"


def format_actions_list(actions):
    actions_str = ""
    for action in actions:
        actions_str += format_one_action(action)
    return actions_str


def read_json_file(filename):
    with open(filename, 'r') as infile:
        data = json.load(infile)
    return data


def load_Mind2Web_dataset(args, save_dataset=False):
    # read text from a file (file name is args.prompt_file)
    with open(args.prompt_file, 'r') as file:
        task_description = file.read().split('===')
    raw_dataset = read_json_file(args.data_dir)

    dataset = []
    for idx, d in enumerate(raw_dataset):
        sequences = []
        input_str = f"## Website:\n{d['website_en']}\n\n## Domain:\n{d['domain_en']}\n\n## Sub-domain:\n{d['subdomain_en']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['task_subintention'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"

        query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
        sequences.append(query_inputs)
        summary_str = d['task_description']
        summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
        sequences.append(summary_str)
        dataset.append({"data": sequences.copy()})

    random.shuffle(dataset)

    if args.max_train_samples is not None:
        dataset = dataset[:args.max_train_samples]
    return dataset


def load_MoTIF_dataset(args, save_dataset=False):
    with open(args.prompt_file, 'r') as file:
        task_description = file.read().split('===')

    raw_dataset = []
    for filename in os.listdir(args.data_dir):
        if filename.endswith('_steps.json'):
            file_path = os.path.join(args.data_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as json_file:
                try:
                    content = json.load(json_file)
                    raw_dataset.append(content)
                except json.JSONDecodeError as e:
                    raise ValueError(f"Error decoding JSON from file {filename}: {e}")

    dataset = []
    for d in raw_dataset:
        sequences = []
        input_str = f"## Application:\n{d['app']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['instr'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
        query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
        sequences.append(query_inputs)
        summary_str = d['goal']
        summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
        sequences.append(summary_str)
        dataset.append({"data": sequences.copy()})

    random.shuffle(dataset)

    if args.max_train_samples is not None:
        dataset = dataset[:args.max_train_samples]
    return dataset


def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
    embedding_model = SentenceTransformer(args.embedding_model_path, device="cuda")
    for param in embedding_model.parameters():
        param.requires_grad = False

    logger.info(f"total training instance number: {len(dataset)}")
    loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
    optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
    optim_manager.add_optimizer(optimizer, lr_scheduler)
    bmt.synchronize()

    avg_time_recorder = bmt.utils.AverageRecorder()
    avg_loss_recorder = bmt.utils.AverageRecorder()
    train_start_time = time.time()
    global_step = 0

    logger.info("split data for each process")
    data_per_gpu = len(dataset) // bmt.world_size()
    dataset = dataset[bmt.rank() * data_per_gpu: (bmt.rank() + 1) * data_per_gpu]
    bmt.print_rank("training on [%d, %d] of the dataset" % (bmt.rank() * data_per_gpu, (bmt.rank() + 1) * data_per_gpu))
    dataset = PromptIterableDataset(
        dataset,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        teacher_forcing=True,
        truncate_method="tail",
    )
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total trainable parameters: {total_params}")

    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params}")

    for epoch in range(args.epochs):
        savefolder = os.path.join(args.save_dir, f"epoch_{epoch}")
        os.makedirs(savefolder, exist_ok=True)

        dataloader = DataLoader(dataset, batch_size=args.batch_size_per_device)

        progress_bar = tqdm(range(len(dataloader)), disable=not bmt.rank() == 0, desc=f"epoch {epoch}")
        logger.info(f"*******start {epoch} epoch training********")
        for step, inputs in enumerate(dataloader):
            if global_step < args.start_step:
                global_step += 1
                progress_bar.update(1)
                continue
            st = time.time()

            with bmt.inspect.inspect_tensor() as inspector:
                for k in inputs:
                    inputs[k] = inputs[k].cuda()

                labels = inputs.pop("labels")
                weight_idxs = inputs.pop('weight_idxs')
                logits = model(**inputs).logits

                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                # Flatten the tokens
                shift_logits = shift_logits.view(-1, len(tokenizer))
                shift_labels = shift_labels.view(-1).to(shift_logits.device)
                ntp_loss = loss_func(shift_logits, shift_labels)

                # Up-weight the tokens that match the action descriptions.
                sample_specific_weights = torch.ones_like(shift_logits)
                weight_idxs = weight_idxs[:, 1:, :].contiguous()
                weight_idxs = weight_idxs.view(-1, weight_idxs.size(-1))
                assert weight_idxs.shape[0] == sample_specific_weights.shape[0], "310"
                sample_specific_weights[weight_idxs == 1] = args.action_weight
                sample_specific_weights = sample_specific_weights[torch.arange(sample_specific_weights.size(0)), shift_labels]

                ntp_loss = (ntp_loss * sample_specific_weights).mean()
                next_token_loss_item = bmt.sum_loss(ntp_loss).item()

                global_loss = next_token_loss_item
                optim_manager.backward(ntp_loss)

                if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
                    optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=args.clip_grad)
                    optim_manager.step()
                    optim_manager.zero_grad()

            global_step += 1
            progress_bar.update(1)

            # record time and loss
            iteration_time = time.time() - st

            avg_time_recorder.record(iteration_time)
            if not np.isnan(global_loss):
                avg_loss_recorder.record(global_loss)

            # print time and loss
            if global_step % args.logging_step == 0:
                bmt.print_rank(
                    "| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} | time: {:.4f} seconds | total_time_passed: {:.4f} minutes".format(
                        global_step,
                        global_loss,
                        avg_loss_recorder.value,
                        lr_scheduler.current_lr,
                        avg_time_recorder.value,
                        (time.time() - train_start_time) / 60
                    )
                )
                if args.wandb and bmt.rank() == 0:
                    wandb.log({
                        "loss": global_loss,
                        "next_token_loss": next_token_loss_item,
                        "average_loss": avg_loss_recorder.value,
                        "lr": lr_scheduler.current_lr,
                    }, step=global_step)

            if global_step == args.train_iters:
                break

        bmt.save(model, os.path.join(savefolder, "pytorch_model.pt"))
        if bmt.rank() == 0:
            tokenizer.save_pretrained(savefolder)
        bmt.print_rank(f"model saved at {savefolder}")


def main():
    args = initialize()
    if "Mind2Web" in args.data_dir:
        dataset = load_Mind2Web_dataset(args, save_dataset=True)
    else:
        assert "MoTIF" in args.data_dir
        dataset = load_MoTIF_dataset(args, save_dataset=True)
    args.train_iters = min(args.epochs * (len(dataset) // (bmt.world_size() * args.batch_size_per_device) + 1), args.train_iters)
    tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
    finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)


if __name__ == "__main__":
    main()
train/train.sh (new executable file, 44 lines)
#! /bin/bash
MASTER_ADDR=localhost
MASTER_PORT=12345
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=2
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
                  --nnodes $NNODES \
                  --node_rank $NODE_RANK \
                  --master_addr $MASTER_ADDR \
                  --master_port $MASTER_PORT"

PROJECT_PATH="your-project-path"

OPTS=""
# model config
MAXSEQLEN=1024
OPTS+=" --max_seq_length ${MAXSEQLEN}"
OPTS+=" --model_name_or_path ${PROJECT_PATH}/Mistral-7b-bmtrain"
# training config
OPTS+=" --logging_step 4"
BATCHSIZE=16
OPTS+=" --batch_size_per_device ${BATCHSIZE}"
OPTS+=" --save_step 500"
OPTS+=" --epochs 15"
LR=1e-6
OPTS+=" --lr ${LR}"
OPTS+=" --warmup_iters 0"
OPTS+=" --start_step 0"
OPTS+=" --loss_scale 6400"
ACTIONWEIGHT=2
OPTS+=" --action_weight ${ACTIONWEIGHT}"
EMBEDDING_MODEL_PATH="${PROJECT_PATH}/sentence-transformer/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/e4ce9877abf3edfe10b0d82785e83bdcb973e22e"
OPTS+=" --embedding_model_path ${EMBEDDING_MODEL_PATH}"

OPTS+=" --prompt_file ${PROJECT_PATH}/prompts/summarisation/summarisation_prompt.txt"
OPTS+=" --save_dir ${PROJECT_PATH}/ckpts/experiment"

CMD="torchrun ${DISTRIBUTED_ARGS} train.py ${OPTS}"

echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
${CMD}