Guanhua Zhang 2025-04-10 20:14:17 +02:00
commit 04c4625cfe
11 changed files with 1330 additions and 0 deletions

95
README.md Normal file

@@ -0,0 +1,95 @@
<div align="center">
<h1> SummAct: Uncovering User Intentions Through Interactive Behaviour Summarisation </h1>
**[Guanhua Zhang][4], &nbsp; [Mohamed Ahmed][3], &nbsp; [Zhiming Hu][5], &nbsp; [Andreas Bulling][6]** <br>
**ACM CHI 2025**, Yokohama, Japan <br>
**[[Project][2]]** **[[Paper][7]]** </div>
----------------
# Directory Structure
```
SummAct
│ README.md
│ environment.yml
└───preprocess
│ convert_dataset.py
│ create_steps.py
└───hf_bmt
│ hf_2_bmtrain.py
│ hf_2_bmtrain.sh
│ bmt_hf.py
└───train
│ train.py
│ train.sh
└───inference
│ inference.py
│ inference.sh
```
# Setup
We recommend setting up a virtual environment using Anaconda. <br>
1. Create a conda environment and install dependencies
```shell
conda env create --name summact --file=environment.yml
conda activate summact
```
2. Since `model_center==1.0.3` is required but not yet available on PyPI, build it from [source](https://github.com/OpenBMB/ModelCenter)
```
$ git clone https://github.com/OpenBMB/ModelCenter.git
$ cd ModelCenter
$ pip install -r requirements.txt
$ python3 setup.py install
```
3. Clone our repository to download our code and a pretrained model
```shell
git clone this_repo.git
```
# Preprocessing
1. Convert actions from symbolic formats to natural language by running `preprocess/convert_dataset.py`. Adapt it to your local dataset paths.
2. Prompt the pretrained LLM with examples to generate sub-intentions using `preprocess/create_steps.py`. Adapt it to your local prompt txt path. Example commands are sketched below.
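For example (a minimal sketch; both scripts run standalone and read the paths hard-coded inside them, so edit those placeholders first):
```shell
# 1. Convert symbolic Mind2Web actions into natural-language descriptions
python3 preprocess/convert_dataset.py
# 2. Prompt the pretrained LLM to generate sub-intentions for each action sequence
python3 preprocess/create_steps.py
```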
# Fine-tuning
1. After downloading the model from Hugging Face, convert it into `model_center` weights using the script in `hf_bmt/hf_2_bmtrain.sh`. Adapt it to the local paths of the downloaded model and the desired output directory.
2. Run `train/train.sh`, which calls `train/train.py` to fine-tune the model for interactive behaviour summarisation. Make sure your machine has GPUs available. Example commands are sketched below.
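For example (a minimal sketch; the scripts contain placeholder paths such as `your-project-path` that you need to fill in, and both call their Python entry points by relative path, so run them from their own directories):
```shell
# Convert the downloaded Hugging Face weights into model_center/BMTrain format
cd hf_bmt && bash hf_2_bmtrain.sh && cd ..
# Fine-tune the converted model; train.sh launches train.py via torchrun
cd train && bash train.sh && cd ..
```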
# Inference
Run `inference/inference.sh`, which calls `inference/inference.py` to convert the fine-tuned model back to the Hugging Face format and then computes metrics to evaluate the summarisation quality.
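For example (a minimal sketch; set `PROJECT_PATH` and the checkpoint paths inside the script first, and run it from its own directory since it calls `inference.py` by relative path):
```shell
cd inference && bash inference.sh
```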
# Citation
If you find our code useful or use it in your own projects, please cite our paper:
```
@inproceedings{zhang25_chi,
title = {SummAct: Uncovering User Intentions Through Interactive Behaviour Summarisation},
author = {Zhang, Guanhua and Ahmed, Mohamed and Hu, Zhiming and Bulling, Andreas},
year = {2025},
pages = {1--17},
booktitle = {Proc. ACM SIGCHI Conference on Human Factors in Computing Systems (CHI)},
doi = {10.1145/3706598.3713190}
}
```
# Acknowledgements
Our work relied on the codebases of [Mind2Web][1], [ScreenAgent][8] and [Tell Me More!][9]. Thanks to the authors for sharing their code.
[1]: https://osu-nlp-group.github.io/Mind2Web/
[2]: https://collaborative-ai.org/publications/zhang25_chi/
[3]: https://www.linkedin.com/in/mohamed-adel-naguib/
[4]: https://scholar.google.com/citations?user=NqkK0GwAAAAJ&hl=en
[5]: https://scholar.google.com/citations?hl=en&user=OLB_xBEAAAAJ
[6]: https://www.collaborative-ai.org/people/bulling/
[7]: https://collaborative-ai.org/publications/zhang25_chi.pdf
[8]: https://github.com/niuzaisheng/ScreenAgent
[9]: https://github.com/OpenBMB/Tell_Me_More

184
environment.yml Normal file

@@ -0,0 +1,184 @@
name: Mistral
channels:
- defaults
- conda-forge
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- asttokens=2.4.1=pyhd8ed1ab_0
- bzip2=1.0.8=hd590300_5
- ca-certificates=2024.2.2=hbcca054_0
- comm=0.2.2=pyhd8ed1ab_0
- debugpy=1.8.1=py311hb755f60_0
- decorator=5.1.1=pyhd8ed1ab_0
- exceptiongroup=1.2.0=pyhd8ed1ab_2
- executing=2.0.1=pyhd8ed1ab_0
- importlib-metadata=7.1.0=pyha770c72_0
- importlib_metadata=7.1.0=hd8ed1ab_0
- ipykernel=6.29.3=pyhd33586a_0
- ipython=8.24.0=pyh707e725_0
- jedi=0.19.1=pyhd8ed1ab_0
- jupyter_client=8.6.1=pyhd8ed1ab_0
- jupyter_core=5.7.2=py311h38be061_0
- keyutils=1.6.1=h166bdaf_0
- krb5=1.21.2=h659d440_0
- ld_impl_linux-64=2.40=h41732ed_0
- libedit=3.1.20191231=he28a2e2_2
- libexpat=2.6.2=h59595ed_0
- libffi=3.4.2=h7f98852_5
- libgcc-ng=13.2.0=h807b86a_5
- libgomp=13.2.0=h807b86a_5
- libnsl=2.0.1=hd590300_0
- libsodium=1.0.18=h36c2ea0_1
- libsqlite=3.45.2=h2797004_0
- libstdcxx-ng=13.2.0=hc0a3c3a_7
- libuuid=2.38.1=h0b41bf4_0
- libxcrypt=4.4.36=hd590300_1
- libzlib=1.2.13=hd590300_5
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- ncurses=6.4.20240210=h59595ed_0
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- openssl=3.3.0=hd590300_0
- packaging=24.0=pyhd8ed1ab_0
- parso=0.8.4=pyhd8ed1ab_0
- pexpect=4.9.0=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pip=24.0=pyhd8ed1ab_0
- platformdirs=4.2.1=pyhd8ed1ab_0
- prompt-toolkit=3.0.42=pyha770c72_0
- psutil=5.9.8=py311h459d7ec_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pygments=2.18.0=pyhd8ed1ab_0
- python=3.11.8=hab00c5b_0_cpython
- python_abi=3.11=4_cp311
- pyzmq=26.0.3=py311h08a0b41_0
- readline=8.2=h8228510_1
- setuptools=69.5.1=pyhd8ed1ab_0
- six=1.16.0=pyh6c4a22f_0
- stack_data=0.6.2=pyhd8ed1ab_0
- tk=8.6.13=noxft_h4845f30_101
- tornado=6.4=py311h459d7ec_0
- traitlets=5.14.3=pyhd8ed1ab_0
- typing_extensions=4.11.0=pyha770c72_0
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.43.0=pyhd8ed1ab_1
- xz=5.2.6=h166bdaf_0
- zeromq=4.3.5=h75354e8_4
- zipp=3.17.0=pyhd8ed1ab_0
- pip:
- absl-py==2.1.0
- accelerate==0.29.2
- aiohttp==3.9.4
- aiosignal==1.3.1
- annotated-types==0.7.0
- antlr4-python3-runtime==4.9.3
- anyio==4.4.0
- appdirs==1.4.4
- attrs==23.2.0
- beautifulsoup4==4.12.3
- bmtrain==1.0.0
- bs4==0.0.2
- certifi==2024.2.2
- charset-normalizer==3.3.2
- click==8.1.7
- colorama==0.4.6
- cprint==1.2.2
- cython==0.29.37
- datasets==2.18.0
- dill==0.3.8
- distro==1.9.0
- docker-pycreds==0.4.0
- evaluate==0.4.1
- filelock==3.13.4
- frozenlist==1.4.1
- fsspec==2024.2.0
- gitdb==4.0.11
- gitpython==3.1.43
- grpcio==1.62.1
- h11==0.14.0
- hdbscan==0.8.37
- httpcore==1.0.5
- httpx==0.27.0
- huggingface-hub==0.22.2
- hydra-core==1.3.2
- idna==3.7
- jieba==0.42.1
- jinja2==3.1.3
- joblib==1.4.0
- keybert==0.8.5
- levenshtein==0.25.1
- lxml==5.2.1
- markdown==3.6
- markdown-it-py==3.0.0
- markupsafe==2.1.5
- mdurl==0.1.2
- mpmath==1.3.0
- multidict==6.0.5
- multiprocess==0.70.16
- networkx==3.3
- nltk==3.8.1
- numpy==1.26.4
- nvidia-cublas-cu12==12.1.3.1
- nvidia-cuda-cupti-cu12==12.1.105
- nvidia-cuda-nvrtc-cu12==12.1.105
- nvidia-cuda-runtime-cu12==12.1.105
- nvidia-cudnn-cu12==8.9.2.26
- nvidia-cufft-cu12==11.0.2.54
- nvidia-curand-cu12==10.3.2.106
- nvidia-cusolver-cu12==11.4.5.107
- nvidia-cusparse-cu12==12.1.0.106
- nvidia-nccl-cu11==2.21.5
- nvidia-nccl-cu12==2.19.3
- nvidia-nvjitlink-cu12==12.4.127
- nvidia-nvtx-cu12==12.1.105
- omegaconf==2.3.0
- openai==1.36.0
- pandas==2.2.2
- pdb-tools==2.5.0
- pillow==10.3.0
- portalocker==2.8.2
- protobuf==4.25.3
- pyarrow==15.0.2
- pyarrow-hotfix==0.6
- pydantic==2.8.2
- pydantic-core==2.20.1
- python-dateutil==2.9.0.post0
- pytz==2024.1
- pyyaml==6.0.1
- rapidfuzz==3.9.6
- regex==2023.12.25
- requests==2.31.0
- responses==0.18.0
- rich==13.7.1
- rouge-score==0.1.2
- sacrebleu==2.4.2
- safetensors==0.4.3
- scikit-learn==1.4.2
- scipy==1.13.0
- sentence-transformers==2.7.0
- sentencepiece==0.2.0
- sentry-sdk==1.45.0
- setproctitle==1.3.3
- smmap==5.0.1
- sniffio==1.3.1
- soupsieve==2.5
- sympy==1.12
- tabulate==0.9.0
- tensorboard==2.16.2
- tensorboard-data-server==0.7.2
- textblob==0.18.0.post0
- threadpoolctl==3.4.0
- tokenizers==0.15.2
- torch==2.2.2
- torchvision==0.17.2
- tqdm==4.66.2
- transformers==4.39.3
- triton==2.2.0
- tzdata==2024.1
- urllib3==2.2.1
- wandb==0.16.6
- werkzeug==3.0.2
- xxhash==3.4.1
- yarl==1.9.4
prefix: /opt/anaconda3/envs/Mistral

92
hf_bmt/bmt_hf.py Normal file

@@ -0,0 +1,92 @@
import os, pdb
import json
import torch
import sys
import shutil
import argparse
from collections import OrderedDict
from transformers import AutoConfig, AutoModelForCausalLM
def transform_to_hf(bmt_model, model_size):
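    # Remap ModelCenter/BMTrain parameter names to the Hugging Face Llama/Mistral state-dict layout.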
model_hf = OrderedDict()
if 'input_embedding.weight' in bmt_model.keys():
model_hf['model.embed_tokens.weight'] = bmt_model["input_embedding.weight"].contiguous().float()
model_hf['model.norm.weight'] = bmt_model["encoder.output_layernorm.weight"].contiguous().float()
try:
model_hf['lm_head.weight'] = bmt_model['output_projection.weight'].contiguous().float()
        except KeyError:  # no separate output projection (tied embeddings); fall back to the input embedding
model_hf['lm_head.weight'] = bmt_model["input_embedding.weight"].contiguous().float()
else:
model_hf['model.embed_tokens.weight'] = bmt_model["LLM.input_embedding.weight"].contiguous().float()
model_hf['model.norm.weight'] = bmt_model["LLM.encoder.output_layernorm.weight"].contiguous().float()
try:
model_hf['lm_head.weight'] = bmt_model['LLM.output_projection.weight'].contiguous().float()
        except KeyError:  # same fallback for the "LLM."-prefixed checkpoint layout
model_hf['lm_head.weight'] = bmt_model["LLM.input_embedding.weight"].contiguous().float()
if model_size == "7b":
layernum = 32
elif model_size == "13b" or model_size == "13b-2":
layernum = 40
elif model_size == "65b":
layernum = 80
for lnum in range(layernum):
hf_pfx = f"model.layers.{lnum}"
if 'input_embedding.weight' in bmt_model.keys():
bmt_pfx = f"encoder.layers.{lnum}"
else:
bmt_pfx = f"LLM.encoder.layers.{lnum}"
model_hf[f"{hf_pfx}.input_layernorm.weight"] = bmt_model[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"].contiguous().float()
model_hf[f"{hf_pfx}.self_attn.q_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_q.weight"].contiguous().float()
model_hf[f"{hf_pfx}.self_attn.k_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_k.weight"].contiguous().float()
model_hf[f"{hf_pfx}.self_attn.v_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.project_v.weight"].contiguous().float()
model_hf[f"{hf_pfx}.self_attn.o_proj.weight"] = bmt_model[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"].contiguous().float()
model_hf[f"{hf_pfx}.post_attention_layernorm.weight"] = bmt_model[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"].contiguous().float()
model_hf[f"{hf_pfx}.mlp.gate_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"].contiguous().float()
model_hf[f"{hf_pfx}.mlp.up_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"].contiguous().float()
model_hf[f"{hf_pfx}.mlp.down_proj.weight"] = bmt_model[f"{bmt_pfx}.ffn.ffn.w_out.weight"].contiguous().float()
for key in model_hf:
model_hf[key] = model_hf[key].bfloat16()
return model_hf
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--in_path", type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--original_mistral_path", type=str)
args = parser.parse_args()
os.makedirs(args.output_path, exist_ok=True)
print("transforming " + args.in_path)
model_size = "7b"
ckpt = [name for name in os.listdir(args.in_path) if name.endswith(".pt")]
bmt_model = torch.load(os.path.join(args.in_path, ckpt[0]))
hf_state_dict = transform_to_hf(bmt_model, model_size)
print(f"start saving to {args.output_path}")
model_config = AutoConfig.from_pretrained(args.original_mistral_path)
model = AutoModelForCausalLM.from_config(model_config)
model.load_state_dict(hf_state_dict)
for param in model.parameters():
param.data = param.data.to(torch.bfloat16)
model.save_pretrained(args.output_path, safe_serialization=False)
for file_name in ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "tokenizer.json"]:
if os.path.exists(os.path.join(args.in_path, file_name)):
shutil.copy(os.path.join(args.in_path, file_name), os.path.join(args.output_path, file_name))
print("saved huggingface checkpoint")

108
hf_bmt/hf_2_bmtrain.py Normal file

@@ -0,0 +1,108 @@
from transformers import LlamaConfig
from transformers import AutoModelForCausalLM
import torch, os
import json
from collections import OrderedDict
import shutil, pdb
import argparse
def initialize():
# get arguments
parser = argparse.ArgumentParser("")
    # Output directory for the BMTrain weights.
    parser.add_argument("--out_path", type=str, default=f"/Mistral-{ver}-bmtrain")
    # Path where you downloaded the Mistral-7B Hugging Face weights.
    parser.add_argument('--in_path', type=str, default=f"/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24")
args = parser.parse_args()
return args
ver = "7b"
# Change these two
# Output Directory for the bmt train weights.
# outpath = f"/Mistral-{ver}-bmtrain"
# Path where you downloaded mistral-7b hugging face weight
# inpath = f"/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"
def convert_weights(args):
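    # Derive a ModelCenter config from the HF config, load all HF weight shards, then remap each tensor to BMTrain names.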
hf_config = LlamaConfig.from_pretrained(args.in_path)
config = {
'dim_model': hf_config.hidden_size,
'dim_ff': hf_config.intermediate_size,
'num_layers': hf_config.num_hidden_layers,
'num_heads': hf_config.num_attention_heads,
'num_heads_kv': hf_config.num_key_value_heads,
'dim_head': hf_config.hidden_size // hf_config.num_attention_heads,
'norm_eps': hf_config.rms_norm_eps,
}
os.makedirs(args.out_path, exist_ok=True)
with open(os.path.join(args.out_path, "config.json"), 'w') as f:
json.dump(config, f)
layernum = config['num_layers']
model_hf = OrderedDict()
ckpt_num = None
if 'v0.1' in args.in_path:
prefix = "pytorch_model-"
endtext = ".bin"
else:
prefix = "model-"
endtext = ".safetensors"
for name in os.listdir(args.in_path):
if name.startswith(prefix) and name.endswith(endtext):
ckpt_num =int(name.split(endtext)[0].split('-')[-1])
for i in range(1, ckpt_num + 1):
if 'v0.1' in args.in_path:
part = torch.load(os.path.join(args.in_path, f"pytorch_model-{i:05d}-of-{ckpt_num:05d}.bin"))
else:
from safetensors import safe_open
with safe_open(os.path.join(args.in_path, f"model-{i:05d}-of-{ckpt_num:05d}.safetensors"), framework="pt", device=0) as f:
part = {}
for k in f.keys():
part[k] = f.get_tensor(k)
model_hf.update(part)
out = OrderedDict()
out["input_embedding.weight"] = model_hf['model.embed_tokens.weight'].contiguous()
out["encoder.output_layernorm.weight"] = model_hf['model.norm.weight'].contiguous()
out['output_projection.weight'] = model_hf['lm_head.weight'].contiguous()
for lnum in range(layernum):
hf_pfx = f"model.layers.{lnum}"
bmt_pfx = f"encoder.layers.{lnum}"
out[f"{bmt_pfx}.self_att.layernorm_before_attention.weight"] = model_hf[f"{hf_pfx}.input_layernorm.weight"].contiguous()
out[f"{bmt_pfx}.self_att.self_attention.project_q.weight"] = model_hf[f"{hf_pfx}.self_attn.q_proj.weight"].contiguous()
out[f"{bmt_pfx}.self_att.self_attention.project_k.weight"] = model_hf[f"{hf_pfx}.self_attn.k_proj.weight"].contiguous()
out[f"{bmt_pfx}.self_att.self_attention.project_v.weight"] = model_hf[f"{hf_pfx}.self_attn.v_proj.weight"].contiguous()
out[f"{bmt_pfx}.self_att.self_attention.attention_out.weight"] = model_hf[f"{hf_pfx}.self_attn.o_proj.weight"].contiguous()
out[f"{bmt_pfx}.ffn.layernorm_before_ffn.weight"] = model_hf[f"{hf_pfx}.post_attention_layernorm.weight"].contiguous()
out[f"{bmt_pfx}.ffn.ffn.w_in.w_0.weight"] = model_hf[f"{hf_pfx}.mlp.gate_proj.weight"].contiguous()
out[f"{bmt_pfx}.ffn.ffn.w_in.w_1.weight"] = model_hf[f"{hf_pfx}.mlp.up_proj.weight"].contiguous()
out[f"{bmt_pfx}.ffn.ffn.w_out.weight"] = model_hf[f"{hf_pfx}.mlp.down_proj.weight"].contiguous()
for key in out:
out[key] = out[key].half()
if not os.path.exists(args.out_path):
os.makedirs(args.out_path)
torch.save(out, os.path.join(args.out_path, "pytorch_model.pt"))
for file_name in ["tokenizer_config.json", "special_tokens_map.json", "tokenizer.model", "tokenizer.json"]:
if os.path.exists(os.path.join(args.in_path, file_name)):
shutil.copy(os.path.join(args.in_path, file_name), os.path.join(args.out_path, file_name))
print("BMT weights created sucessfully")
def main():
args = initialize()
convert_weights(args)
if __name__ == "__main__":
main()

13
hf_bmt/hf_2_bmtrain.sh Normal file

@@ -0,0 +1,13 @@
#!/bin/bash
IN_PATH="your-path-to-hf-model"
OUT_PATH="your-wanted-path-to-bm-model"
OPTS=""
OPTS+="--in_path ${IN_PATH} "
OPTS+="--out_path ${OUT_PATH}"
CMD="python3 hf_2_bmtrain.py ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
eval ${CMD}

158
inference/inference.py Normal file

@@ -0,0 +1,158 @@
from transformers import AutoTokenizer, AutoModelForCausalLM
import argparse
import os, pdb
import numpy as np
import json
from pathlib import Path
import json
from tqdm import tqdm
from cprint import cprint
import evaluate
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
import logging
import torch
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
from sentence_transformers import SentenceTransformer, util
def initialize():
parser = argparse.ArgumentParser("")
parser.add_argument("--model_name_or_path", type=str, default='')
parser.add_argument("--embedding_model_path", type=str, default="")
parser.add_argument("--train_data_dir", type=str, default='')
parser.add_argument("--test_data_dir", type=str, default='')
parser.add_argument("--prompt_file", type=str, default=None, help="The file for loading the prompt")
args = parser.parse_args()
return args
def get_tokenizer(args):
tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, device_map={"":0})
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
return tokenizer
def get_model(args):
model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, device_map={"":0})
return model
def setup_model_and_tokenizer(args):
tokenizer = get_tokenizer(args)
model = get_model(args)
return tokenizer, model
def read_json_file(filename):
with open(filename, 'r') as infile:
data = json.load(infile)
return data
def format_one_action(action):
return f"- {action}\n"
def format_actions_list(actions):
actions_str = ""
for action in actions:
actions_str += format_one_action(action)
return actions_str
def preprocess_data(task, args):
with open(args.prompt_file, 'r') as file:
task_description = file.read().split('===')
input_str = f"## Website:\n{task['website_en']}\n\n## Domain:\n{task['domain_en']}\n\n## Sub-domain:\n{task['subdomain_en']}\n\n## Actions (Each line is one action):\n{format_actions_list(task['task_subintention'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(task['steps'])}"
query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
summary_str = task['task_description']
summary_str = summary_str[0].upper() + summary_str[1:] + "."
test_prompt = f"User: {query_inputs}\nAgent:"
return {"task": summary_str, "prompt": test_prompt}
def load_raw_dataset(data, args):
tasks = []
for d in tqdm(data):
processed_task = preprocess_data(d, args)
tasks.append(processed_task)
return tasks
def main_loop(args, test_dataset, tokenizer, model, sacrebleu, rouge, meteor, embedding_model, mark):
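    # For each test sample: reuse the cached per-sample result if it exists, otherwise generate a summary and score it (SacreBLEU, ROUGE, METEOR, sentence-embedding similarity/distance).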
os.makedirs(args.model_name_or_path+"/results/", exist_ok=True)
global_sacrebleu, global_rouge1, global_rouge2, global_rougeL, global_rougeLsum, global_meteor, global_cosine, global_distance = [], [], [], [], [], [], [], []
for i, data in tqdm(enumerate(test_dataset)):
save_task_response_filename = args.model_name_or_path + f"/results/{mark}_{i}_insert_mistral.json"
if os.path.exists(save_task_response_filename):
with open(save_task_response_filename, 'r') as f:
save_dict = json.load(f)
else:
prompt = data["prompt"]
task = data["task"]
save_dict = {}
model_inputs = tokenizer([prompt], return_tensors="pt").to("cuda")
generated_ids = model.generate(**model_inputs,max_new_tokens=1024, do_sample=False, top_p= 0.95, repetition_penalty=1.2)
pred = tokenizer.batch_decode(generated_ids)[0]
response = pred.split("[SUMMARY]")[-1].replace('</s>','').strip()
rouge_calc = rouge.compute(predictions = [response], references=[[task]], use_aggregator=True)
sacrebleu_calc = sacrebleu.compute(predictions = [response], references=[[task]])
meteor_calc = meteor.compute(predictions = [response], references=[[task]])
GT_Embedding= embedding_model.encode(task.lower(), convert_to_tensor=True)
Prediction_Embedding = embedding_model.encode(response.lower(), convert_to_tensor=True)
cosine_similarity = util.cos_sim(GT_Embedding, Prediction_Embedding).item()
euclidean_disance = torch.sqrt(torch.sum(torch.pow(torch.subtract(GT_Embedding, Prediction_Embedding), 2))).item()
save_dict["prompt"] = prompt
save_dict["prediction"] = response
save_dict["task"] = task
save_dict["sacrebleu"] = sacrebleu_calc
save_dict["rouge"] = rouge_calc
save_dict["meteor"] = meteor_calc
save_dict["cosine_similarity"] = cosine_similarity
save_dict["euclidean_disance"] = euclidean_disance
with open(save_task_response_filename, 'w') as f:
json.dump(save_dict, f)
global_sacrebleu.append(save_dict["sacrebleu"]["score"])
global_rouge1.append(save_dict["rouge"]["rouge1"])
global_rouge2.append(save_dict["rouge"]["rouge2"])
global_rougeL.append(save_dict["rouge"]["rougeL"])
global_rougeLsum.append(save_dict["rouge"]["rougeLsum"])
global_meteor.append(save_dict["meteor"]["meteor"])
global_cosine.append(save_dict["cosine_similarity"])
global_distance.append(save_dict["euclidean_disance"])
return global_sacrebleu, global_rouge1, global_rouge2, global_rougeL, global_rougeLsum, global_meteor, global_cosine, global_distance
def main(mark):
args = initialize()
assert 'Mind2Web' in args.test_data_dir
tokenizer, model = setup_model_and_tokenizer(args)
    sacrebleu = evaluate.load('sacrebleu', module_type="metric")
    rouge = evaluate.load('rouge', module_type="metric")
    meteor = evaluate.load('meteor', module_type="metric")
embedding_model = SentenceTransformer(args.embedding_model_path, device="cuda")
test_folders_names = ["test_domain", "test_task", "test_website"]
for name in test_folders_names:
test_folder_path = Path(os.path.join(args.test_data_dir,name))
global_sacrebleu, global_rouge1, global_rouge2, global_rougeL, global_rougeLsum, global_meteor, global_cosine, global_distance = [], [], [], [], [], [], [], []
for json_file in test_folder_path.rglob('*_with_steps_insert_mistral.json'):
with json_file.open('r') as f:
data = json.load(f)
raw_tasks = load_raw_dataset(data, args)
sacrebleu_calc, rouge1_calc, rouge2_calc, rougeL_calc, rougeLsum_calc, meteor_calc, cosine_calc, distance_calc = main_loop(args, raw_tasks, tokenizer, model, sacrebleu, rouge, meteor, embedding_model, 'test_%s'%(name))
global_sacrebleu.extend(sacrebleu_calc)
global_rouge1.extend(rouge1_calc)
global_rouge2.extend(rouge2_calc)
global_rougeL.extend(rougeL_calc)
global_rougeLsum.extend(rougeLsum_calc)
global_meteor.extend(meteor_calc)
global_cosine.extend(cosine_calc)
global_distance.extend(distance_calc)
print(mark, name)
print("%.3f" % (np.mean(global_cosine)))
print("%.3f" % (np.mean(global_sacrebleu)/100.0))
print("%.3f" % (np.mean(global_rougeL)))
print("%.3f" % (np.mean(global_meteor)))
if __name__ == "__main__":
main('test')

30
inference/inference.sh Normal file

@@ -0,0 +1,30 @@
PROJECT_PATH="your-project-path"
EMBEDDING_MODEL_PATH="${PROJECT_PATH}/sentence-transformer/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/e4ce9877abf3edfe10b0d82785e83bdcb973e22e"
OPTS=""
OPTS+=" --embedding_model_path ${EMBEDDING_MODEL_PATH}"
OPTS+=" --test_data_dir ${PROJECT_PATH}/data/Mind2Web/test"
OPTS+=" --train_data_dir ${PROJECT_PATH}/data/Mind2Web/train/train_with_steps_insert_mistral.json"
OPTS+=" --prompt_file ${PROJECT_PATH}/prompts/summarisation/summarisation_prompt.txt"
MODEL_NAME_OR_PATH_BMT="${PROJECT_PATH}/ckpts/experiment/epoch_14"
MODEL_NAME_OR_PATH_HF="${MODEL_NAME_OR_PATH_BMT}-hf"
MODEL_NAME_OR_PATH_ORIGINAL_MISTRAL="${PROJECT_PATH}/Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"
# Convert the model to HF format
if [ ! -f "${MODEL_NAME_OR_PATH_HF}/config.json" ]; then
CMD="python3 ${PROJECT_PATH}/hf_bmt/bmt_hf.py --in_path ${MODEL_NAME_OR_PATH_BMT} --output_path ${MODEL_NAME_OR_PATH_HF} --original_mistral_path ${MODEL_NAME_OR_PATH_ORIGINAL_MISTRAL}"
echo "-------BMT -> HF CMD is------"
echo "CMD: ${CMD}"
echo "-------BMT -> HF CMD end------"
eval ${CMD}
fi
OPTS+=" --model_name_or_path ${MODEL_NAME_OR_PATH_HF}"
CMD="python3 inference.py ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
eval ${CMD}

135
preprocess/convert_dataset.py Normal file

@@ -0,0 +1,135 @@
import os, pdb
import re
import json
from enum import Enum
from tqdm import tqdm
from bs4 import BeautifulSoup
def read_json_file(filename):
with open(filename, 'r') as infile:
data = json.load(infile)
return data
def convert_string(string_or_list):
# Add escaping symbols to English quotes in string
if isinstance(string_or_list, str):
return string_or_list.replace('"', '\\"')
elif isinstance(string_or_list, list):
return [convert_string(s) for s in string_or_list]
def is_visible(element):
bounding_box = element.get('bounding_box_rect')
return bounding_box != "-1,-1,-1,-1"
def clean_text(text):
cleaned_text = text.strip()
cleaned_text = cleaned_text.replace('\n', ' ').replace('\t', ' ')
cleaned_text = re.sub(' +', ' ', cleaned_text)
return cleaned_text
def find_semantic_info(element):
element_text = clean_text(element.get_text(strip=True))
if element_text:
return element_text
label = element.find_previous(lambda x: x.name == 'label' and is_visible(x))
if label:
label_text = clean_text(label.get_text(strip=True))
if label_text:
return label_text
return None
def action_description(ui_element_name, ui_element_text, operation_type, value):
ret_en = ""
if operation_type == "TYPE":
if ui_element_text != "":
ret_en += f'Type text "{value}" into {ui_element_name} with text "{ui_element_text}" on it'
else:
ret_en += f'Type text "{value}" into {ui_element_name}'
elif operation_type == "SELECT":
if ui_element_text != "":
ret_en += f'Select "{value}" from {ui_element_name} with text "{ui_element_text}" on it'
else:
ret_en += f'Select "{value}" from {ui_element_name}.'
elif operation_type == "CLICK":
if ui_element_text != "":
ret_en += f'Click the {ui_element_name} element with text "{ui_element_text}" on it'
else:
ret_en += f'Click the {ui_element_name} element'
return ret_en
def process_one_task(task):
base_info = {
"website_en": task["website"],
"domain_en": task["domain"],
"subdomain_en": task["subdomain"],
"annotation_id":task["annotation_id"],
"task_description": task["confirmed_task"],
"action_reprs" : task["action_reprs"]
}
action_descriptions_en = []
for action_index, action in enumerate(task["actions"]):
action_repr = task["action_reprs"][action_index]
ui_element, _ = action_repr.split(" -> ")
assert ui_element.count("] ")==1
ui_element_name, ui_element_text = ui_element.split("] ")
ui_element_name = ui_element_name[1:]
ui_element_text = ui_element_text.strip()
if ui_element_text == "":
raw_html = action["raw_html"]
soup2 = BeautifulSoup(raw_html, 'html.parser')
selected_element2 = soup2.find(attrs={"data_pw_testid_buckeye": action["action_uid"]})
ui_element_text = find_semantic_info(selected_element2)
if ui_element_text is not None:
ui_element_text = clean_text(ui_element_text)
task["action_reprs"][action_index] = f"[{ui_element_name}] {ui_element_text} -> {task['action_reprs'][action_index].split(' -> ')[1]}"
else:
                print(f'Warning: {task["annotation_id"]}, cannot find semantic info for {action["action_uid"]}')
        action_description_en = action_description(ui_element_name, ui_element_text, action["operation"]["op"], action["operation"]["value"])
action_descriptions_en.append(action_description_en)
base_info["task_subintention"] = action_descriptions_en
return base_info
if __name__ == "__main__":
for foldername in ['train','test_domain','test_website','test_task']:
SAVE_PATH = f"your-path-to-data/{foldername}"
for idx in range(100):
savejsonfilename = os.path.join(SAVE_PATH,f'{foldername}_{idx}_with_actions_description_insert.json')
if os.path.exists(savejsonfilename):
continue
else:
jsonfilename = f"{SAVE_PATH}/{foldername}_{idx}.json"
if not os.path.exists(jsonfilename):
break
dataset = read_json_file(jsonfilename)
Mind2Web_with_subintentions = []
for task in tqdm(dataset):
base_info = process_one_task(task)
Mind2Web_with_subintentions.append(base_info)
assert len(Mind2Web_with_subintentions) == len(dataset)
if 'test' in foldername:
with open(os.path.join(SAVE_PATH,f'{foldername}_{idx}_with_actions_description.json'), 'r') as json_file:
Mind2Web_with_subintentions_saved = json.load(json_file)
for i in range(len(Mind2Web_with_subintentions)):
if i>=len(Mind2Web_with_subintentions_saved):
break
if Mind2Web_with_subintentions[i] != Mind2Web_with_subintentions_saved[i]:
for key in Mind2Web_with_subintentions[i].keys():
if Mind2Web_with_subintentions[i][key] != Mind2Web_with_subintentions_saved[i][key]:
found = False
for j in range(len(Mind2Web_with_subintentions_saved)):
if Mind2Web_with_subintentions[i][key] == Mind2Web_with_subintentions_saved[j][key]:
found = True
break
if not found:
print(found, i, j, jsonfilename)
with open(savejsonfilename, 'w') as json_file:
json.dump(Mind2Web_with_subintentions, json_file)

97
preprocess/create_steps.py Normal file

@@ -0,0 +1,97 @@
from tqdm import tqdm
import json
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
def get_tokenizer(model_name_or_path):
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, device_map={"":0})
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
return tokenizer
def get_model(model_name_or_path):
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, device_map={"":0})
return model
def read_json_file(filename):
with open(filename, 'r') as infile:
data = json.load(infile)
return data
if __name__ == "__main__":
model_name_or_path = "Mistral-7B-v0.1/snapshots/26bca36bde8333b5d7f72e9ed20ccda6a618af24"
tokenizer = get_tokenizer(model_name_or_path)
model = get_model(model_name_or_path)
# load prompts
with open("your-path-to-data/train_prompt.txt", "r") as f:
train_prompt = f.read()
with open("your-path-to-data/test_prompt.txt", "r") as f:
test_prompt = f.read()
for foldername in ['train','test_domain','test_website','test_task']:
SAVE_PATH = f"your-path-to-data/{foldername}"
for idx in range(100):
savejsonfilename = f"{SAVE_PATH}/{foldername}_{idx}_with_steps_insert_mistral.json"
jsonfilename = f"{SAVE_PATH}/{foldername}_{idx}_with_actions_description_insert.json"
if not os.path.exists(jsonfilename):
break
data = read_json_file(jsonfilename)
if os.path.exists(savejsonfilename):
data = read_json_file(savejsonfilename)
actions_steps = []
for i in tqdm(range(len(data)), desc="Steps_Creation"):
if "train" in foldername: # include task
message = f"""Website: {data[i]["website_en"]}
Domain: {data[i]["domain_en"]}
Sub-domain: {data[i]["subdomain_en"]}
Task: {data[i]["task_description"]}
Actions: {data[i]["task_subintention"]}\n
# OUTPUT #
"""
prompt = train_prompt
else: # exclude task
message = f"""Website: {data[i]["website_en"]}
Domain: {data[i]["domain_en"]}
Sub-domain: {data[i]["subdomain_en"]}
Actions: {data[i]["task_subintention"]}\n
# OUTPUT #
"""
prompt = test_prompt
                # The chat-style message list is flattened into a single prompt string, since the base Mistral model is used without a chat template.
                messages = 'System: ' + prompt + 'User: ' + message
model_inputs = tokenizer(messages, return_tensors="pt").to("cuda")
assert len(model_inputs['input_ids'])<=4096
generated_ids = model.generate(**model_inputs,max_new_tokens=1024, do_sample=False, top_p= 0.95, repetition_penalty=1.2)
json_object = tokenizer.batch_decode(generated_ids)[0]
answer = json_object.split('Sub-intentions: [')[-1].split('\n')
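                # Clean up the generated list: strip the EOS token, surrounding quotes and trailing commas/brackets, and drop empty entries.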
final_answer = []
for a in answer:
a = a.strip()
if '</s>' in a:
a = a.split('</s>')[0]
if len(a)==0:
continue
while a[0]=='"':
a = a[1:]
if len(a)==0:
break
if len(a)==0:
continue
while a[-1] in ['"', ',', ']', ]:
a = a[:-1]
if len(a)==0:
break
if len(a)==0:
continue
final_answer.append(a)
data[i]['steps'] = final_answer
with open(savejsonfilename, 'w') as json_file:
json.dump(data, json_file)

374
train/train.py Normal file

@@ -0,0 +1,374 @@
import argparse
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader
import bmtrain as bmt
from functools import partial
import time
import os, pdb, shutil
import random
import json
from model_center.model import Llama
from model_center.tokenizer import LlamaTokenizer
from functools import partial
from dataset_wrapper import PromptIterableDataset, collator
import wandb
import csv
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:32"
import logging
import numpy as np
import math
from sentence_transformers import SentenceTransformer, util
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def set_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def get_tokenizer(args):
tokenizer = LlamaTokenizer.from_pretrained(args.model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'
return tokenizer
def get_model(args):
model = Llama.from_pretrained(args.model_name_or_path)
if args.load_ckpt is not None:
logger.info(f"loading model from {args.load_ckpt}")
bmt.load(model, os.path.join(args.load_ckpt, "pytorch_model.pt"))
return model
def get_optimizer(args, model):
optimizer = bmt.optim.AdamOffloadOptimizer(
model.parameters(),
weight_decay=args.weight_decay,
eps=1e-5,
betas=(0.9, 0.95)
)
if args.load_ckpt is not None:
file_name = os.path.join(args.load_ckpt, "optim.rank-{}.opt".format(bmt.rank()))
logger.info(file_name)
if os.path.exists(file_name):
logger.info("start to load gradient ckpt {}".format(file_name))
states = torch.load(file_name)
optimizer.load_state_dict(states)
return optimizer
def get_learning_rate_scheduler(args, optimizer):
if args.lr_decay_iters is None:
args.lr_decay_iters = args.train_iters
if args.lr_decay_style == "linear":
lr_scheduler = bmt.lr_scheduler.Linear(
optimizer,
start_lr=args.lr,
warmup_iter=int(args.warmup_iters),
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
elif args.lr_decay_style == "cosine":
bmt.print_rank("use cosine")
class Cosine(bmt.lr_scheduler.WarmupLRScheduler):
def get_lr_warmup(self, num_iter) -> float:
return self.start_lr * num_iter / self.warmup_iter
def get_lr_decay(self, num_iter) -> float:
progress = (num_iter - self.warmup_iter) / max(1, (self.end_iter - self.warmup_iter))
return max(self.start_lr * 0.1, self.start_lr * (0.1 + 0.45 * (1.0 + math.cos(progress * math.pi))))
lr_scheduler = Cosine(
optimizer,
start_lr=args.lr,
warmup_iter=int(args.warmup_iters),
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
elif args.lr_decay_style == "noam":
logger.info("use noam")
lr_scheduler = bmt.lr_scheduler.Noam(
optimizer,
start_lr=args.lr,
warmup_iter=int(args.warmup_iters),
end_iter=args.lr_decay_iters,
num_iter=args.start_step,
)
else:
raise NotImplementedError
return lr_scheduler
def setup_model_and_optimizer(args):
# get the tokenizer
tokenizer = get_tokenizer(args)
# get the model
model = get_model(args)
bmt.synchronize()
# get the optimizer and lr_scheduler
optimizer = get_optimizer(args, model)
lr_scheduler = get_learning_rate_scheduler(args, optimizer)
bmt.synchronize()
return tokenizer, model, optimizer, lr_scheduler
def initialize():
parser = argparse.ArgumentParser("")
# model training arguments
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--model_name_or_path")
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_seq_length", default=2048, type=int)
parser.add_argument("--batch_size_per_device", default=2, type=int)
parser.add_argument("--logging_step", default=100, type=int)
parser.add_argument("--save_step", default=50000, type=int)
parser.add_argument("--gradient_accumulation_steps", default=1, type=int)
parser.add_argument("--wandb", default= True ,action="store_true")
parser.add_argument("--with_eval", action="store_true")
parser.add_argument("--clip_grad", type=float, default=1.0, help="gradient clipping")
parser.add_argument("--weight_decay", type=float, default=0.0, help="weight decay rate")
parser.add_argument("--loss_scale", type=float, default=6553600, help="loss scale")
parser.add_argument("--train_iters", type=int, default=2000000)
# loss parameters
parser.add_argument("--action_weight", type=float, help="weight of the tokens that match the action")
parser.add_argument("--embedding_model_path", type=str, help="The path to the sentence embedding model")
# data parameters
parser.add_argument('--data_setting', type=str ,help='MTSD or MTMD', default="MTMD")
parser.add_argument('--data_dir', type=str, help='The directory for saving the dataset')
parser.add_argument('--max_train_samples', type=int, help='The maximum number of training samples')
parser.add_argument('--cache_dir', type=str, help='The directory for cache')
parser.add_argument("--save_dir", type=str, default="")
parser.add_argument("--save_limit", type=int, default=None, help="ckpt saved limit number")
parser.add_argument("--warmup_iters", type=int, default=1000)
parser.add_argument(
"--lr_decay_style",
type=str,
default="cosine",
choices=["constant", "linear", "cosine", "exponential", "noam"],
help="learning rate decay function",
)
parser.add_argument("--lr_decay_iters", type=int, default=None, help="lr decay steps")
parser.add_argument("--start_step", type=int, default=0, help="step to start or continue training")
parser.add_argument("--load_ckpt", type=str, default=None, help="resumed ckpt")
parser.add_argument("--save_processed_data", action='store_true', help="wheather or no save the processed data")
parser.add_argument("--prompt_file", type=str, default=None, help="The file for loading the prompt")
args = parser.parse_args()
# init bmt
bmt.init_distributed(seed=args.seed)
set_seed(args.seed)
# wandb
if args.wandb and bmt.rank() == 0:
        wandb.init(project='Mistral-Interact', config=args, name=os.path.basename(os.path.normpath(args.save_dir)), save_code=True, settings=wandb.Settings(code_dir="."))  # run name follows the save directory
return args
def format_one_action(action):
return f"- {action}\n"
def format_actions_list(actions):
actions_str = ""
for action in actions:
actions_str += format_one_action(action)
return actions_str
def read_json_file(filename):
with open(filename, 'r') as infile:
data = json.load(infile)
return data
def load_Mind2Web_dataset(args, save_dataset= False):
# read text from a file (file name is args.prompt_file)
with open(args.prompt_file, 'r') as file:
task_description = file.read().split('===')
raw_dataset = read_json_file(args.data_dir)
dataset=[]
for idx, d in enumerate(raw_dataset):
sequences = []
input_str = f"## Website:\n{d['website_en']}\n\n## Domain:\n{d['domain_en']}\n\n## Sub-domain:\n{d['subdomain_en']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['task_subintention'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
sequences.append(query_inputs)
summary_str = d['task_description']
summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
sequences.append(summary_str)
dataset.append({"data": sequences.copy()})
random.shuffle(dataset)
if args.max_train_samples is not None:
dataset = dataset[:args.max_train_samples]
return dataset
def load_MoTIF_dataset(args, save_dataset= False):
with open(args.prompt_file, 'r') as file:
task_description = file.read().split('===')
raw_dataset = []
for filename in os.listdir(args.data_dir):
if filename.endswith('_steps.json'):
file_path = os.path.join(args.data_dir, filename)
with open(file_path, 'r', encoding='utf-8') as json_file:
try:
content = json.load(json_file)
raw_dataset.append(content)
except json.JSONDecodeError as e:
raise ValueError(f"Error decoding JSON from file {filename}: {e}")
dataset=[]
for d in raw_dataset:
sequences = []
input_str = f"## Application:\n{d['app']}\n\n## Actions (Each line is one action):\n{format_actions_list(d['instr'])}\n## Sub-intentions summarised from these actions:\n{format_actions_list(d['steps'])}"
query_inputs = f"{task_description[0]}\n{input_str}{task_description[1]}\n"
sequences.append(query_inputs)
summary_str = d['goal']
summary_str = "[SUMMARY] " + summary_str[0].upper() + summary_str[1:]
sequences.append(summary_str)
dataset.append({"data": sequences.copy()})
random.shuffle(dataset)
if args.max_train_samples is not None:
dataset = dataset[:args.max_train_samples]
return dataset
def finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset):
embedding_model = SentenceTransformer(args.embedding_model_path, device="cuda")
for param in embedding_model.parameters():
param.requires_grad = False
logger.info(f"total training instance number: {len(dataset)}")
loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, reduction="none")
optim_manager = bmt.optim.OptimManager(loss_scale=args.loss_scale)
optim_manager.add_optimizer(optimizer, lr_scheduler)
bmt.synchronize()
avg_time_recorder = bmt.utils.AverageRecorder()
avg_loss_recorder = bmt.utils.AverageRecorder()
train_start_time = time.time()
global_step = 0
logger.info("split data for each process")
data_per_gpu = len(dataset) // bmt.world_size()
dataset = dataset[bmt.rank() * data_per_gpu: (bmt.rank() + 1) * data_per_gpu]
bmt.print_rank("training on [%d, %d] of the dataset" % (bmt.rank() * data_per_gpu, (bmt.rank() + 1) * data_per_gpu))
dataset = PromptIterableDataset(
dataset,
tokenizer=tokenizer,
max_seq_length=args.max_seq_length,
teacher_forcing=True,
truncate_method="tail",
)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total trainable parameters: {total_params}")
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")
for epoch in range(args.epochs):
savefolder = os.path.join(args.save_dir, f"epoch_{epoch}")
os.makedirs(savefolder, exist_ok=True)
dataloader = DataLoader(dataset, batch_size=args.batch_size_per_device)
progress_bar = tqdm(range(len(dataloader)), disable=not bmt.rank()==0, desc=f"epoch {epoch}")
logger.info(f"*******start {epoch} epoch training********")
for step, inputs in enumerate(dataloader):
if global_step < args.start_step:
global_step += 1
progress_bar.update(1)
continue
st = time.time()
with bmt.inspect.inspect_tensor() as inspector:
for k in inputs:
inputs[k] = inputs[k].cuda()
labels = inputs.pop("labels")
weight_idxs = inputs.pop('weight_idxs')
logits = model(**inputs).logits
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
shift_logits = shift_logits.view(-1, len(tokenizer))
shift_labels = shift_labels.view(-1).to(shift_logits.device)
ntp_loss = loss_func(shift_logits, shift_labels)
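                # Action-weighted loss: token positions flagged in weight_idxs (tokens matching actions) are weighted by args.action_weight instead of 1.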
sample_specific_weights = torch.ones_like(shift_logits)
weight_idxs = weight_idxs[:, 1:, :].contiguous()
weight_idxs = weight_idxs.view(-1, weight_idxs.size(-1))
assert weight_idxs.shape[0] == sample_specific_weights.shape[0], "310"
sample_specific_weights[weight_idxs==1] = args.action_weight
sample_specific_weights = sample_specific_weights[torch.arange(sample_specific_weights.size(0)), shift_labels]
ntp_loss = (ntp_loss * sample_specific_weights).mean()
next_token_loss_item = bmt.sum_loss(ntp_loss).item()
global_loss = next_token_loss_item
optim_manager.backward(ntp_loss)
if (step + 1) % args.gradient_accumulation_steps == 0 or step == len(dataloader) - 1:
optim_manager.clip_grad_norm(optimizer.param_groups, max_norm=args.clip_grad)
optim_manager.step()
optim_manager.zero_grad()
global_step += 1
progress_bar.update(1)
# record time and loss
iteration_time = time.time() - st
avg_time_recorder.record(iteration_time)
if not np.isnan(global_loss):
avg_loss_recorder.record(global_loss)
# print time and loss
if global_step % args.logging_step == 0:
bmt.print_rank(
"| Iter: {:6d} | loss: {:.4f} average_loss: {:.4f} | lr: {:.4e} | time: {:.4f} seconds | total_time_passed: {:.4f} minutes".format(
global_step,
global_loss,
avg_loss_recorder.value,
lr_scheduler.current_lr,
avg_time_recorder.value,
(time.time() - train_start_time) / 60
)
)
if args.wandb and bmt.rank() == 0:
wandb.log({
"loss": global_loss,
"next_token_loss": next_token_loss_item,
"average_loss": avg_loss_recorder.value,
"lr": lr_scheduler.current_lr,
}, step=global_step)
if global_step == args.train_iters:
break
bmt.save(model, os.path.join(savefolder, "pytorch_model.pt"))
if bmt.rank() == 0:
tokenizer.save_pretrained(savefolder)
bmt.print_rank(f"model saved at {savefolder}")
def main():
args = initialize()
if "Mind2Web" in args.data_dir:
dataset = load_Mind2Web_dataset(args, save_dataset=True)
else:
assert "MoTIF" in args.data_dir
dataset = load_MoTIF_dataset(args, save_dataset=True)
args.train_iters = min(args.epochs * (len(dataset) // (bmt.world_size() * args.batch_size_per_device) + 1), args.train_iters)
tokenizer, model, optimizer, lr_scheduler = setup_model_and_optimizer(args)
finetune(args, tokenizer, model, optimizer, lr_scheduler, dataset)
if __name__ == "__main__":
main()

44
train/train.sh Executable file

@@ -0,0 +1,44 @@
#! /bin/bash
MASTER_ADDR=localhost
MASTER_PORT=12345
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=2
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
--nnodes $NNODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT"
PROJECT_PATH="your-project-path"
OPTS=""
# model config
MAXSEQLEN=1024
OPTS+=" --max_seq_length ${MAXSEQLEN}"
OPTS+=" --model_name_or_path ${PROJECT_PATH}/Mistral-7b-bmtrain"
# training config
OPTS+=" --logging_step 4"
BATCHSIZE=16
OPTS+=" --batch_size_per_device ${BATCHSIZE}"
OPTS+=" --save_step 500"
OPTS+=" --epochs 15"
LR=1e-6
OPTS+=" --lr ${LR}"
OPTS+=" --warmup_iters 0"
OPTS+=" --start_step 0"
OPTS+=" --loss_scale 6400"
ACTIONWEIGHT=2
OPTS+=" --action_weight ${ACTIONWEIGHT}"
EMBEDDING_MODEL_PATH="${PROJECT_PATH}/sentence-transformer/models--sentence-transformers--all-MiniLM-L6-v2/snapshots/e4ce9877abf3edfe10b0d82785e83bdcb973e22e"
OPTS+=" --embedding_model_path ${EMBEDDING_MODEL_PATH}"
OPTS+=" --prompt_file ${PROJECT_PATH}/prompts/summarisation/summarisation_prompt.txt"
OPTS+=" --save_dir ${PROJECT_PATH}/ckpts/experiment"
CMD="torchrun ${DISTRIBUTED_ARGS} train.py ${OPTS}"
echo "-------final CMD is------"
echo "${CMD}"
echo "-------final CMD end------"
${CMD}