up
This commit is contained in:
parent
08780752d9
commit
e15b0d7b50
46 changed files with 14927 additions and 0 deletions
339
LICENSE
Normal file
339
LICENSE
Normal file
|
@ -0,0 +1,339 @@
|
||||||
|
GNU GENERAL PUBLIC LICENSE
|
||||||
|
Version 2, June 1991
|
||||||
|
|
||||||
|
Copyright (C) 1989, 1991 Free Software Foundation, Inc.,
|
||||||
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
|
||||||
|
Everyone is permitted to copy and distribute verbatim copies
|
||||||
|
of this license document, but changing it is not allowed.
|
||||||
|
|
||||||
|
Preamble
|
||||||
|
|
||||||
|
The licenses for most software are designed to take away your
|
||||||
|
freedom to share and change it. By contrast, the GNU General Public
|
||||||
|
License is intended to guarantee your freedom to share and change free
|
||||||
|
software--to make sure the software is free for all its users. This
|
||||||
|
General Public License applies to most of the Free Software
|
||||||
|
Foundation's software and to any other program whose authors commit to
|
||||||
|
using it. (Some other Free Software Foundation software is covered by
|
||||||
|
the GNU Lesser General Public License instead.) You can apply it to
|
||||||
|
your programs, too.
|
||||||
|
|
||||||
|
When we speak of free software, we are referring to freedom, not
|
||||||
|
price. Our General Public Licenses are designed to make sure that you
|
||||||
|
have the freedom to distribute copies of free software (and charge for
|
||||||
|
this service if you wish), that you receive source code or can get it
|
||||||
|
if you want it, that you can change the software or use pieces of it
|
||||||
|
in new free programs; and that you know you can do these things.
|
||||||
|
|
||||||
|
To protect your rights, we need to make restrictions that forbid
|
||||||
|
anyone to deny you these rights or to ask you to surrender the rights.
|
||||||
|
These restrictions translate to certain responsibilities for you if you
|
||||||
|
distribute copies of the software, or if you modify it.
|
||||||
|
|
||||||
|
For example, if you distribute copies of such a program, whether
|
||||||
|
gratis or for a fee, you must give the recipients all the rights that
|
||||||
|
you have. You must make sure that they, too, receive or can get the
|
||||||
|
source code. And you must show them these terms so they know their
|
||||||
|
rights.
|
||||||
|
|
||||||
|
We protect your rights with two steps: (1) copyright the software, and
|
||||||
|
(2) offer you this license which gives you legal permission to copy,
|
||||||
|
distribute and/or modify the software.
|
||||||
|
|
||||||
|
Also, for each author's protection and ours, we want to make certain
|
||||||
|
that everyone understands that there is no warranty for this free
|
||||||
|
software. If the software is modified by someone else and passed on, we
|
||||||
|
want its recipients to know that what they have is not the original, so
|
||||||
|
that any problems introduced by others will not reflect on the original
|
||||||
|
authors' reputations.
|
||||||
|
|
||||||
|
Finally, any free program is threatened constantly by software
|
||||||
|
patents. We wish to avoid the danger that redistributors of a free
|
||||||
|
program will individually obtain patent licenses, in effect making the
|
||||||
|
program proprietary. To prevent this, we have made it clear that any
|
||||||
|
patent must be licensed for everyone's free use or not licensed at all.
|
||||||
|
|
||||||
|
The precise terms and conditions for copying, distribution and
|
||||||
|
modification follow.
|
||||||
|
|
||||||
|
GNU GENERAL PUBLIC LICENSE
|
||||||
|
TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION
|
||||||
|
|
||||||
|
0. This License applies to any program or other work which contains
|
||||||
|
a notice placed by the copyright holder saying it may be distributed
|
||||||
|
under the terms of this General Public License. The "Program", below,
|
||||||
|
refers to any such program or work, and a "work based on the Program"
|
||||||
|
means either the Program or any derivative work under copyright law:
|
||||||
|
that is to say, a work containing the Program or a portion of it,
|
||||||
|
either verbatim or with modifications and/or translated into another
|
||||||
|
language. (Hereinafter, translation is included without limitation in
|
||||||
|
the term "modification".) Each licensee is addressed as "you".
|
||||||
|
|
||||||
|
Activities other than copying, distribution and modification are not
|
||||||
|
covered by this License; they are outside its scope. The act of
|
||||||
|
running the Program is not restricted, and the output from the Program
|
||||||
|
is covered only if its contents constitute a work based on the
|
||||||
|
Program (independent of having been made by running the Program).
|
||||||
|
Whether that is true depends on what the Program does.
|
||||||
|
|
||||||
|
1. You may copy and distribute verbatim copies of the Program's
|
||||||
|
source code as you receive it, in any medium, provided that you
|
||||||
|
conspicuously and appropriately publish on each copy an appropriate
|
||||||
|
copyright notice and disclaimer of warranty; keep intact all the
|
||||||
|
notices that refer to this License and to the absence of any warranty;
|
||||||
|
and give any other recipients of the Program a copy of this License
|
||||||
|
along with the Program.
|
||||||
|
|
||||||
|
You may charge a fee for the physical act of transferring a copy, and
|
||||||
|
you may at your option offer warranty protection in exchange for a fee.
|
||||||
|
|
||||||
|
2. You may modify your copy or copies of the Program or any portion
|
||||||
|
of it, thus forming a work based on the Program, and copy and
|
||||||
|
distribute such modifications or work under the terms of Section 1
|
||||||
|
above, provided that you also meet all of these conditions:
|
||||||
|
|
||||||
|
a) You must cause the modified files to carry prominent notices
|
||||||
|
stating that you changed the files and the date of any change.
|
||||||
|
|
||||||
|
b) You must cause any work that you distribute or publish, that in
|
||||||
|
whole or in part contains or is derived from the Program or any
|
||||||
|
part thereof, to be licensed as a whole at no charge to all third
|
||||||
|
parties under the terms of this License.
|
||||||
|
|
||||||
|
c) If the modified program normally reads commands interactively
|
||||||
|
when run, you must cause it, when started running for such
|
||||||
|
interactive use in the most ordinary way, to print or display an
|
||||||
|
announcement including an appropriate copyright notice and a
|
||||||
|
notice that there is no warranty (or else, saying that you provide
|
||||||
|
a warranty) and that users may redistribute the program under
|
||||||
|
these conditions, and telling the user how to view a copy of this
|
||||||
|
License. (Exception: if the Program itself is interactive but
|
||||||
|
does not normally print such an announcement, your work based on
|
||||||
|
the Program is not required to print an announcement.)
|
||||||
|
|
||||||
|
These requirements apply to the modified work as a whole. If
|
||||||
|
identifiable sections of that work are not derived from the Program,
|
||||||
|
and can be reasonably considered independent and separate works in
|
||||||
|
themselves, then this License, and its terms, do not apply to those
|
||||||
|
sections when you distribute them as separate works. But when you
|
||||||
|
distribute the same sections as part of a whole which is a work based
|
||||||
|
on the Program, the distribution of the whole must be on the terms of
|
||||||
|
this License, whose permissions for other licensees extend to the
|
||||||
|
entire whole, and thus to each and every part regardless of who wrote it.
|
||||||
|
|
||||||
|
Thus, it is not the intent of this section to claim rights or contest
|
||||||
|
your rights to work written entirely by you; rather, the intent is to
|
||||||
|
exercise the right to control the distribution of derivative or
|
||||||
|
collective works based on the Program.
|
||||||
|
|
||||||
|
In addition, mere aggregation of another work not based on the Program
|
||||||
|
with the Program (or with a work based on the Program) on a volume of
|
||||||
|
a storage or distribution medium does not bring the other work under
|
||||||
|
the scope of this License.
|
||||||
|
|
||||||
|
3. You may copy and distribute the Program (or a work based on it,
|
||||||
|
under Section 2) in object code or executable form under the terms of
|
||||||
|
Sections 1 and 2 above provided that you also do one of the following:
|
||||||
|
|
||||||
|
a) Accompany it with the complete corresponding machine-readable
|
||||||
|
source code, which must be distributed under the terms of Sections
|
||||||
|
1 and 2 above on a medium customarily used for software interchange; or,
|
||||||
|
|
||||||
|
b) Accompany it with a written offer, valid for at least three
|
||||||
|
years, to give any third party, for a charge no more than your
|
||||||
|
cost of physically performing source distribution, a complete
|
||||||
|
machine-readable copy of the corresponding source code, to be
|
||||||
|
distributed under the terms of Sections 1 and 2 above on a medium
|
||||||
|
customarily used for software interchange; or,
|
||||||
|
|
||||||
|
c) Accompany it with the information you received as to the offer
|
||||||
|
to distribute corresponding source code. (This alternative is
|
||||||
|
allowed only for noncommercial distribution and only if you
|
||||||
|
received the program in object code or executable form with such
|
||||||
|
an offer, in accord with Subsection b above.)
|
||||||
|
|
||||||
|
The source code for a work means the preferred form of the work for
|
||||||
|
making modifications to it. For an executable work, complete source
|
||||||
|
code means all the source code for all modules it contains, plus any
|
||||||
|
associated interface definition files, plus the scripts used to
|
||||||
|
control compilation and installation of the executable. However, as a
|
||||||
|
special exception, the source code distributed need not include
|
||||||
|
anything that is normally distributed (in either source or binary
|
||||||
|
form) with the major components (compiler, kernel, and so on) of the
|
||||||
|
operating system on which the executable runs, unless that component
|
||||||
|
itself accompanies the executable.
|
||||||
|
|
||||||
|
If distribution of executable or object code is made by offering
|
||||||
|
access to copy from a designated place, then offering equivalent
|
||||||
|
access to copy the source code from the same place counts as
|
||||||
|
distribution of the source code, even though third parties are not
|
||||||
|
compelled to copy the source along with the object code.
|
||||||
|
|
||||||
|
4. You may not copy, modify, sublicense, or distribute the Program
|
||||||
|
except as expressly provided under this License. Any attempt
|
||||||
|
otherwise to copy, modify, sublicense or distribute the Program is
|
||||||
|
void, and will automatically terminate your rights under this License.
|
||||||
|
However, parties who have received copies, or rights, from you under
|
||||||
|
this License will not have their licenses terminated so long as such
|
||||||
|
parties remain in full compliance.
|
||||||
|
|
||||||
|
5. You are not required to accept this License, since you have not
|
||||||
|
signed it. However, nothing else grants you permission to modify or
|
||||||
|
distribute the Program or its derivative works. These actions are
|
||||||
|
prohibited by law if you do not accept this License. Therefore, by
|
||||||
|
modifying or distributing the Program (or any work based on the
|
||||||
|
Program), you indicate your acceptance of this License to do so, and
|
||||||
|
all its terms and conditions for copying, distributing or modifying
|
||||||
|
the Program or works based on it.
|
||||||
|
|
||||||
|
6. Each time you redistribute the Program (or any work based on the
|
||||||
|
Program), the recipient automatically receives a license from the
|
||||||
|
original licensor to copy, distribute or modify the Program subject to
|
||||||
|
these terms and conditions. You may not impose any further
|
||||||
|
restrictions on the recipients' exercise of the rights granted herein.
|
||||||
|
You are not responsible for enforcing compliance by third parties to
|
||||||
|
this License.
|
||||||
|
|
||||||
|
7. If, as a consequence of a court judgment or allegation of patent
|
||||||
|
infringement or for any other reason (not limited to patent issues),
|
||||||
|
conditions are imposed on you (whether by court order, agreement or
|
||||||
|
otherwise) that contradict the conditions of this License, they do not
|
||||||
|
excuse you from the conditions of this License. If you cannot
|
||||||
|
distribute so as to satisfy simultaneously your obligations under this
|
||||||
|
License and any other pertinent obligations, then as a consequence you
|
||||||
|
may not distribute the Program at all. For example, if a patent
|
||||||
|
license would not permit royalty-free redistribution of the Program by
|
||||||
|
all those who receive copies directly or indirectly through you, then
|
||||||
|
the only way you could satisfy both it and this License would be to
|
||||||
|
refrain entirely from distribution of the Program.
|
||||||
|
|
||||||
|
If any portion of this section is held invalid or unenforceable under
|
||||||
|
any particular circumstance, the balance of the section is intended to
|
||||||
|
apply and the section as a whole is intended to apply in other
|
||||||
|
circumstances.
|
||||||
|
|
||||||
|
It is not the purpose of this section to induce you to infringe any
|
||||||
|
patents or other property right claims or to contest validity of any
|
||||||
|
such claims; this section has the sole purpose of protecting the
|
||||||
|
integrity of the free software distribution system, which is
|
||||||
|
implemented by public license practices. Many people have made
|
||||||
|
generous contributions to the wide range of software distributed
|
||||||
|
through that system in reliance on consistent application of that
|
||||||
|
system; it is up to the author/donor to decide if he or she is willing
|
||||||
|
to distribute software through any other system and a licensee cannot
|
||||||
|
impose that choice.
|
||||||
|
|
||||||
|
This section is intended to make thoroughly clear what is believed to
|
||||||
|
be a consequence of the rest of this License.
|
||||||
|
|
||||||
|
8. If the distribution and/or use of the Program is restricted in
|
||||||
|
certain countries either by patents or by copyrighted interfaces, the
|
||||||
|
original copyright holder who places the Program under this License
|
||||||
|
may add an explicit geographical distribution limitation excluding
|
||||||
|
those countries, so that distribution is permitted only in or among
|
||||||
|
countries not thus excluded. In such case, this License incorporates
|
||||||
|
the limitation as if written in the body of this License.
|
||||||
|
|
||||||
|
9. The Free Software Foundation may publish revised and/or new versions
|
||||||
|
of the General Public License from time to time. Such new versions will
|
||||||
|
be similar in spirit to the present version, but may differ in detail to
|
||||||
|
address new problems or concerns.
|
||||||
|
|
||||||
|
Each version is given a distinguishing version number. If the Program
|
||||||
|
specifies a version number of this License which applies to it and "any
|
||||||
|
later version", you have the option of following the terms and conditions
|
||||||
|
either of that version or of any later version published by the Free
|
||||||
|
Software Foundation. If the Program does not specify a version number of
|
||||||
|
this License, you may choose any version ever published by the Free Software
|
||||||
|
Foundation.
|
||||||
|
|
||||||
|
10. If you wish to incorporate parts of the Program into other free
|
||||||
|
programs whose distribution conditions are different, write to the author
|
||||||
|
to ask for permission. For software which is copyrighted by the Free
|
||||||
|
Software Foundation, write to the Free Software Foundation; we sometimes
|
||||||
|
make exceptions for this. Our decision will be guided by the two goals
|
||||||
|
of preserving the free status of all derivatives of our free software and
|
||||||
|
of promoting the sharing and reuse of software generally.
|
||||||
|
|
||||||
|
NO WARRANTY
|
||||||
|
|
||||||
|
11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY
|
||||||
|
FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN
|
||||||
|
OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES
|
||||||
|
PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED
|
||||||
|
OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
|
||||||
|
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS
|
||||||
|
TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE
|
||||||
|
PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING,
|
||||||
|
REPAIR OR CORRECTION.
|
||||||
|
|
||||||
|
12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||||
|
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR
|
||||||
|
REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES,
|
||||||
|
INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING
|
||||||
|
OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED
|
||||||
|
TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY
|
||||||
|
YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
|
||||||
|
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
|
||||||
|
POSSIBILITY OF SUCH DAMAGES.
|
||||||
|
|
||||||
|
END OF TERMS AND CONDITIONS
|
||||||
|
|
||||||
|
How to Apply These Terms to Your New Programs
|
||||||
|
|
||||||
|
If you develop a new program, and you want it to be of the greatest
|
||||||
|
possible use to the public, the best way to achieve this is to make it
|
||||||
|
free software which everyone can redistribute and change under these terms.
|
||||||
|
|
||||||
|
To do so, attach the following notices to the program. It is safest
|
||||||
|
to attach them to the start of each source file to most effectively
|
||||||
|
convey the exclusion of warranty; and each file should have at least
|
||||||
|
the "copyright" line and a pointer to where the full notice is found.
|
||||||
|
|
||||||
|
<one line to give the program's name and a brief idea of what it does.>
|
||||||
|
Copyright (C) <year> <name of author>
|
||||||
|
|
||||||
|
This program is free software; you can redistribute it and/or modify
|
||||||
|
it under the terms of the GNU General Public License as published by
|
||||||
|
the Free Software Foundation; either version 2 of the License, or
|
||||||
|
(at your option) any later version.
|
||||||
|
|
||||||
|
This program is distributed in the hope that it will be useful,
|
||||||
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||||
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||||
|
GNU General Public License for more details.
|
||||||
|
|
||||||
|
You should have received a copy of the GNU General Public License along
|
||||||
|
with this program; if not, write to the Free Software Foundation, Inc.,
|
||||||
|
51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
|
||||||
|
|
||||||
|
Also add information on how to contact you by electronic and paper mail.
|
||||||
|
|
||||||
|
If the program is interactive, make it output a short notice like this
|
||||||
|
when it starts in an interactive mode:
|
||||||
|
|
||||||
|
Gnomovision version 69, Copyright (C) year name of author
|
||||||
|
Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
|
||||||
|
This is free software, and you are welcome to redistribute it
|
||||||
|
under certain conditions; type `show c' for details.
|
||||||
|
|
||||||
|
The hypothetical commands `show w' and `show c' should show the appropriate
|
||||||
|
parts of the General Public License. Of course, the commands you use may
|
||||||
|
be called something other than `show w' and `show c'; they could even be
|
||||||
|
mouse-clicks or menu items--whatever suits your program.
|
||||||
|
|
||||||
|
You should also get your employer (if you work as a programmer) or your
|
||||||
|
school, if any, to sign a "copyright disclaimer" for the program, if
|
||||||
|
necessary. Here is a sample; alter the names:
|
||||||
|
|
||||||
|
Yoyodyne, Inc., hereby disclaims all copyright interest in the program
|
||||||
|
`Gnomovision' (which makes passes at compilers) written by James Hacker.
|
||||||
|
|
||||||
|
<signature of Ty Coon>, 1 April 1989
|
||||||
|
Ty Coon, President of Vice
|
||||||
|
|
||||||
|
This General Public License does not permit incorporating your program into
|
||||||
|
proprietary programs. If your program is a subroutine library, you may
|
||||||
|
consider it more useful to permit linking proprietary applications with the
|
||||||
|
library. If this is what you want to do, use the GNU Lesser General
|
||||||
|
Public License instead of this License.
|
0
__init__.py
Executable file
0
__init__.py
Executable file
296
baselines_with_dialogue_moves.py
Executable file
296
baselines_with_dialogue_moves.py
Executable file
|
@ -0,0 +1,296 @@
|
||||||
|
from glob import glob
|
||||||
|
import os, json, sys
|
||||||
|
import torch, random, torch.nn as nn, numpy as np
|
||||||
|
from torch import optim
|
||||||
|
from random import shuffle
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score
|
||||||
|
from src.data.game_parser import GameParser, make_splits, onehot, DEVICE, set_seed
|
||||||
|
from src.models.model_with_dialogue_moves import Model
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
def print_epoch(data,acc_loss,lst):
|
||||||
|
print(f'{acc_loss/len(lst):9.4f}',end='; ',flush=True)
|
||||||
|
data = list(zip(*data))
|
||||||
|
for x in data:
|
||||||
|
a, b = list(zip(*x))
|
||||||
|
if max(a) <= 1:
|
||||||
|
print(f'({accuracy_score(a,b):5.3f},{f1_score(a,b,average="weighted"):5.3f},{sum(a)/len(a):5.3f},{sum(b)/len(b):5.3f},{len(b)})', end=' ',flush=True)
|
||||||
|
else:
|
||||||
|
print(f'({accuracy_score(a,b):5.3f},{f1_score(a,b,average="weighted"):5.3f},{len(b)})', end=' ',flush=True)
|
||||||
|
print('', end='; ',flush=True)
|
||||||
|
|
||||||
|
def do_split(model,lst,exp,criterion,optimizer=None,global_plan=False, player_plan=False,device=DEVICE):
|
||||||
|
data = []
|
||||||
|
acc_loss = 0
|
||||||
|
for game in lst:
|
||||||
|
|
||||||
|
if model.training and (not optimizer is None): optimizer.zero_grad()
|
||||||
|
|
||||||
|
l = model(game, global_plan=global_plan, player_plan=player_plan)
|
||||||
|
prediction = []
|
||||||
|
ground_truth = []
|
||||||
|
for gt, prd in l:
|
||||||
|
lbls = [int(a==b) for a,b in zip(gt[0],gt[1])]
|
||||||
|
lbls += [['NO', 'MAYBE', 'YES'].index(gt[0][0]),['NO', 'MAYBE', 'YES'].index(gt[0][1])]
|
||||||
|
if gt[0][2] in game.materials_dict:
|
||||||
|
lbls.append(game.materials_dict[gt[0][2]])
|
||||||
|
else:
|
||||||
|
lbls.append(0)
|
||||||
|
lbls += [['NO', 'MAYBE', 'YES'].index(gt[1][0]),['NO', 'MAYBE', 'YES'].index(gt[1][1])]
|
||||||
|
if gt[1][2] in game.materials_dict:
|
||||||
|
lbls.append(game.materials_dict[gt[1][2]])
|
||||||
|
else:
|
||||||
|
lbls.append(0)
|
||||||
|
prd = prd[exp:exp+1]
|
||||||
|
lbls = lbls[exp:exp+1]
|
||||||
|
data.append([(g,torch.argmax(p).item()) for p,g in zip(prd,lbls)])
|
||||||
|
# p, g = zip(*[(p,torch.eye(p.shape[0]).float()[g]) for p,g in zip(prd,lbls)])
|
||||||
|
if exp == 0:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls) if gt==0 or (random.random() < 2/3)]))
|
||||||
|
elif exp == 1:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls) if gt==0 or (random.random() < 5/6)]))
|
||||||
|
elif exp == 2:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls) if gt==1 or (random.random() < 5/6)]))
|
||||||
|
else:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls)]))
|
||||||
|
# print(pairs)
|
||||||
|
if pairs:
|
||||||
|
p,g = pairs
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
# print(p,g)
|
||||||
|
prediction.append(torch.cat(p))
|
||||||
|
|
||||||
|
# ground_truth.append(torch.cat(g))
|
||||||
|
ground_truth += g
|
||||||
|
|
||||||
|
if prediction:
|
||||||
|
prediction = torch.stack(prediction)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if ground_truth:
|
||||||
|
# ground_truth = torch.stack(ground_truth).float().to(DEVICE)
|
||||||
|
ground_truth = torch.tensor(ground_truth).long().to(device)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
|
||||||
|
loss = criterion(prediction,ground_truth)
|
||||||
|
|
||||||
|
if model.training and (not optimizer is None):
|
||||||
|
loss.backward()
|
||||||
|
# nn.utils.clip_grad_norm_(model.parameters(), 10)
|
||||||
|
nn.utils.clip_grad_norm_(model.parameters(), 1)
|
||||||
|
optimizer.step()
|
||||||
|
acc_loss += loss.item()
|
||||||
|
# return data, acc_loss + loss.item()
|
||||||
|
print_epoch(data,acc_loss,lst)
|
||||||
|
return acc_loss, data
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args, flush=True)
|
||||||
|
print(f'PID: {os.getpid():6d}', flush=True)
|
||||||
|
|
||||||
|
if isinstance(args.device, int) and args.device >= 0:
|
||||||
|
DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Using {DEVICE}')
|
||||||
|
else:
|
||||||
|
print('Device must be a zero or positive integer, but got',args.device)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if isinstance(args.seed, int) and args.seed >= 0:
|
||||||
|
seed = set_seed(args.seed)
|
||||||
|
else:
|
||||||
|
print('Seed must be a zero or positive integer, but got',args.seed)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_dev.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_old.json')
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
|
||||||
|
if args.use_dialogue=='Yes':
|
||||||
|
d_flag = True
|
||||||
|
elif args.use_dialogue=='No':
|
||||||
|
d_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.use_dialogue_moves=='Yes':
|
||||||
|
d_move_flag = True
|
||||||
|
elif args.use_dialogue_moves=='No':
|
||||||
|
d_move_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.experiment in list(range(9)):
|
||||||
|
print('Experiment must be in',list(range(9)),', but got',args.experiment)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
if args.seq_model=='GRU':
|
||||||
|
seq_model = 0
|
||||||
|
elif args.seq_model=='LSTM':
|
||||||
|
seq_model = 1
|
||||||
|
elif args.seq_model=='Transformer':
|
||||||
|
seq_model = 2
|
||||||
|
else:
|
||||||
|
print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.plans=='Yes':
|
||||||
|
global_plan = (args.pov=='Third') or ((args.pov=='None') and (args.experiment in list(range(3))))
|
||||||
|
player_plan = (args.pov=='First') or ((args.pov=='None') and (args.experiment in list(range(3,9))))
|
||||||
|
elif args.plans=='No' or args.plans is None:
|
||||||
|
global_plan = False
|
||||||
|
player_plan = False
|
||||||
|
else:
|
||||||
|
print('Use Plan must be in [Yes, No], but got',args.plan)
|
||||||
|
exit()
|
||||||
|
print('global_plan', global_plan, 'player_plan', player_plan)
|
||||||
|
|
||||||
|
if args.pov=='None':
|
||||||
|
val = [GameParser(f,d_flag,0,0,d_move_flag) for f in dataset_splits['validation']]
|
||||||
|
train = [GameParser(f,d_flag,0,0,d_move_flag) for f in dataset_splits['training']]
|
||||||
|
if args.experiment > 2:
|
||||||
|
val += [GameParser(f,d_flag,4,0,d_move_flag) for f in dataset_splits['validation']]
|
||||||
|
train += [GameParser(f,d_flag,4,0,d_move_flag) for f in dataset_splits['training']]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
val = [GameParser(f,d_flag,3,0,d_move_flag) for f in dataset_splits['validation']]
|
||||||
|
train = [GameParser(f,d_flag,3,0,d_move_flag) for f in dataset_splits['training']]
|
||||||
|
elif args.pov=='First':
|
||||||
|
val = [GameParser(f,d_flag,1,0,d_move_flag) for f in dataset_splits['validation']]
|
||||||
|
train = [GameParser(f,d_flag,1,0,d_move_flag) for f in dataset_splits['training']]
|
||||||
|
val += [GameParser(f,d_flag,2,0,d_move_flag) for f in dataset_splits['validation']]
|
||||||
|
train += [GameParser(f,d_flag,2,0,d_move_flag) for f in dataset_splits['training']]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
model = Model(seq_model,DEVICE).to(DEVICE)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
learning_rate = 1e-4
|
||||||
|
num_epochs = 1000#2#1#
|
||||||
|
weight_decay=1e-4
|
||||||
|
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
# optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
# optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
|
||||||
|
# optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)
|
||||||
|
# optimizer = optim.Adadelta(model.parameters())
|
||||||
|
# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
# criterion = nn.MSELoss()
|
||||||
|
|
||||||
|
print(str(criterion), str(optimizer))
|
||||||
|
|
||||||
|
min_acc_loss = 100
|
||||||
|
max_f1 = 0
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
wait_epoch = 100
|
||||||
|
|
||||||
|
if args.model_path is not None:
|
||||||
|
print(f'Loading {args.model_path}')
|
||||||
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
|
acc_loss, data = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
data = list(zip(*data))
|
||||||
|
for x in data:
|
||||||
|
a, b = list(zip(*x))
|
||||||
|
f1 = f1_score(a,b,average='weighted')
|
||||||
|
f1 = f1_score(a,b,average='weighted')
|
||||||
|
if (max_f1 < f1):
|
||||||
|
max_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
print('Training model from scratch', flush=True)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
print(f'{os.getpid():6d} {epoch+1:4d},',end=' ', flush=True)
|
||||||
|
shuffle(train)
|
||||||
|
model.train()
|
||||||
|
do_split(model,train,args.experiment,criterion,optimizer=optimizer,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
model.eval()
|
||||||
|
acc_loss, data = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
|
||||||
|
data = list(zip(*data))
|
||||||
|
for x in data:
|
||||||
|
a, b = list(zip(*x))
|
||||||
|
f1 = f1_score(a,b,average='weighted')
|
||||||
|
if (max_f1 < f1):
|
||||||
|
max_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
epochs_since_improvement += 1
|
||||||
|
print()
|
||||||
|
# if (min_acc_loss > acc_loss):
|
||||||
|
# min_acc_loss = acc_loss
|
||||||
|
# epochs_since_improvement = 0
|
||||||
|
# print('^')
|
||||||
|
# else:
|
||||||
|
# epochs_since_improvement += 1
|
||||||
|
# print()
|
||||||
|
|
||||||
|
if epoch > wait_epoch and epochs_since_improvement > 20:
|
||||||
|
break
|
||||||
|
print()
|
||||||
|
print('Test')
|
||||||
|
model.load_state_dict(torch.load(args.save_path))
|
||||||
|
|
||||||
|
val = None
|
||||||
|
train = None
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,0,d_move_flag) for f in dataset_splits['test']]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,0,d_move_flag) for f in dataset_splits['test']]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,0,d_move_flag) for f in dataset_splits['test']]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,0,d_move_flag) for f in dataset_splits['test']]
|
||||||
|
test += [GameParser(f,d_flag,2,0,d_move_flag) for f in dataset_splits['test']]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
_, data = do_split(model,test,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
parser.add_argument('--pov', type=str,
|
||||||
|
help='point of view [None, First, Third]')
|
||||||
|
parser.add_argument('--use_dialogue', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--use_dialogue_moves', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--plans', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--seq_model', type=str,
|
||||||
|
help='point of view [GRU, LSTM, Transformer]')
|
||||||
|
parser.add_argument('--experiment', type=int,
|
||||||
|
help='point of view [0:AggQ1, 1:AggQ2, 2:AggQ3, 3:P0Q1, 4:P0Q2, 5:P0Q3, 6:P1Q1, 7:P1Q2, 8:P1Q3]')
|
||||||
|
parser.add_argument('--seed', type=int,
|
||||||
|
help='Selet random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
|
||||||
|
'set to n\'th random number with original seed set to 0')
|
||||||
|
parser.add_argument('--save_path', type=str,
|
||||||
|
help='path where to save model')
|
||||||
|
parser.add_argument('--model_path', type=str, default=None,
|
||||||
|
help='path to the pretrained model to be loaded')
|
||||||
|
parser.add_argument('--device', type=int, default=0,
|
||||||
|
help='select cuda device number')
|
||||||
|
|
||||||
|
main(parser.parse_args())
|
106
baselines_with_dialogue_moves.sh
Executable file
106
baselines_with_dialogue_moves.sh
Executable file
|
@ -0,0 +1,106 @@
|
||||||
|
|
||||||
|
FOLDER="models/tom_lstm_baseline"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
SEED=$2
|
||||||
|
|
||||||
|
for MODEL in LSTM; do # LSTM; do # LSTM Transformer; do #
|
||||||
|
for DLGM in Yes; do # No; do # Yes; do # No
|
||||||
|
for EXP in 6 7 8; do # 2 3; do
|
||||||
|
|
||||||
|
DLG="No"
|
||||||
|
POV="None"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}"
|
||||||
|
COMM="baselines_with_dialogue_moves.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="Yes"
|
||||||
|
POV="None"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}"
|
||||||
|
COMM="baselines_with_dialogue_moves.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_No_pov_None_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="No"
|
||||||
|
POV="First"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}"
|
||||||
|
COMM="baselines_with_dialogue_moves.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_No_pov_None_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="Yes"
|
||||||
|
POV="First"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}_DlgFirst"
|
||||||
|
COMM="baselines_with_dialogue_moves.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_Yes_pov_None_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="Yes"
|
||||||
|
POV="First"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}_VidFirst"
|
||||||
|
COMM="baselines_with_dialogue_moves.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_No_pov_First_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
echo "Done!"
|
||||||
|
|
296
baselines_with_dialogue_moves_graphs.py
Normal file
296
baselines_with_dialogue_moves_graphs.py
Normal file
|
@ -0,0 +1,296 @@
|
||||||
|
import os
|
||||||
|
import torch, random, torch.nn as nn, numpy as np
|
||||||
|
from torch import optim
|
||||||
|
from random import shuffle
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score
|
||||||
|
from src.data.game_parser_graphs_new import GameParser, make_splits, DEVICE, set_seed
|
||||||
|
from src.models.model_with_dialogue_moves_graphs import Model
|
||||||
|
import argparse
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
def print_epoch(data,acc_loss,lst):
|
||||||
|
print(f'{acc_loss/len(lst):9.4f}',end='; ',flush=True)
|
||||||
|
data = list(zip(*data))
|
||||||
|
for x in data:
|
||||||
|
a, b = list(zip(*x))
|
||||||
|
if max(a) <= 1:
|
||||||
|
print(f'({accuracy_score(a,b):5.3f},{f1_score(a,b,average="weighted"):5.3f},{sum(a)/len(a):5.3f},{sum(b)/len(b):5.3f},{len(b)})', end=' ',flush=True)
|
||||||
|
else:
|
||||||
|
print(f'({accuracy_score(a,b):5.3f},{f1_score(a,b,average="weighted"):5.3f},{len(b)})', end=' ',flush=True)
|
||||||
|
print('', end='; ',flush=True)
|
||||||
|
|
||||||
|
def do_split(model,lst,exp,criterion,optimizer=None,global_plan=False, player_plan=False,device=DEVICE):
|
||||||
|
data = []
|
||||||
|
acc_loss = 0
|
||||||
|
for game in lst:
|
||||||
|
|
||||||
|
if model.training and (not optimizer is None): optimizer.zero_grad()
|
||||||
|
|
||||||
|
l = model(game, global_plan=global_plan, player_plan=player_plan)
|
||||||
|
prediction = []
|
||||||
|
ground_truth = []
|
||||||
|
for gt, prd in l:
|
||||||
|
lbls = [int(a==b) for a,b in zip(gt[0],gt[1])]
|
||||||
|
lbls += [['NO', 'MAYBE', 'YES'].index(gt[0][0]),['NO', 'MAYBE', 'YES'].index(gt[0][1])]
|
||||||
|
if gt[0][2] in game.materials_dict:
|
||||||
|
lbls.append(game.materials_dict[gt[0][2]])
|
||||||
|
else:
|
||||||
|
lbls.append(0)
|
||||||
|
lbls += [['NO', 'MAYBE', 'YES'].index(gt[1][0]),['NO', 'MAYBE', 'YES'].index(gt[1][1])]
|
||||||
|
if gt[1][2] in game.materials_dict:
|
||||||
|
lbls.append(game.materials_dict[gt[1][2]])
|
||||||
|
else:
|
||||||
|
lbls.append(0)
|
||||||
|
prd = prd[exp:exp+1]
|
||||||
|
lbls = lbls[exp:exp+1]
|
||||||
|
data.append([(g,torch.argmax(p).item()) for p,g in zip(prd,lbls)])
|
||||||
|
# p, g = zip(*[(p,torch.eye(p.shape[0]).float()[g]) for p,g in zip(prd,lbls)])
|
||||||
|
if exp == 0:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls) if gt==0 or (random.random() < 2/3)]))
|
||||||
|
elif exp == 1:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls) if gt==0 or (random.random() < 5/6)]))
|
||||||
|
elif exp == 2:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls) if gt==1 or (random.random() < 5/6)]))
|
||||||
|
else:
|
||||||
|
pairs = list(zip(*[(pr,gt) for pr,gt in zip(prd,lbls)]))
|
||||||
|
# print(pairs)
|
||||||
|
if pairs:
|
||||||
|
p,g = pairs
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
# print(p,g)
|
||||||
|
prediction.append(torch.cat(p))
|
||||||
|
|
||||||
|
# ground_truth.append(torch.cat(g))
|
||||||
|
ground_truth += g
|
||||||
|
|
||||||
|
if prediction:
|
||||||
|
prediction = torch.stack(prediction)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
if ground_truth:
|
||||||
|
# ground_truth = torch.stack(ground_truth).float().to(DEVICE)
|
||||||
|
ground_truth = torch.tensor(ground_truth).long().to(device)
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
loss = criterion(prediction,ground_truth)
|
||||||
|
|
||||||
|
if model.training and (not optimizer is None):
|
||||||
|
loss.backward()
|
||||||
|
# nn.utils.clip_grad_norm_(model.parameters(), 10)
|
||||||
|
nn.utils.clip_grad_norm_(model.parameters(), 1)
|
||||||
|
optimizer.step()
|
||||||
|
acc_loss += loss.item()
|
||||||
|
# return data, acc_loss + loss.item()
|
||||||
|
print_epoch(data,acc_loss,lst)
|
||||||
|
return acc_loss, data
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args, flush=True)
|
||||||
|
print(f'PID: {os.getpid():6d}', flush=True)
|
||||||
|
|
||||||
|
if isinstance(args.device, int) and args.device >= 0:
|
||||||
|
DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Using {DEVICE}')
|
||||||
|
else:
|
||||||
|
print('Device must be a zero or positive integer, but got',args.device)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if isinstance(args.seed, int) and args.seed >= 0:
|
||||||
|
seed = set_seed(args.seed)
|
||||||
|
else:
|
||||||
|
print('Seed must be a zero or positive integer, but got',args.seed)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_dev.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_old.json')
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
|
||||||
|
if args.use_dialogue=='Yes':
|
||||||
|
d_flag = True
|
||||||
|
elif args.use_dialogue=='No':
|
||||||
|
d_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.use_dialogue_moves=='Yes':
|
||||||
|
d_move_flag = True
|
||||||
|
elif args.use_dialogue_moves=='No':
|
||||||
|
d_move_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.experiment in list(range(9)):
|
||||||
|
print('Experiment must be in',list(range(9)),', but got',args.experiment)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
if args.seq_model=='GRU':
|
||||||
|
seq_model = 0
|
||||||
|
elif args.seq_model=='LSTM':
|
||||||
|
seq_model = 1
|
||||||
|
elif args.seq_model=='Transformer':
|
||||||
|
seq_model = 2
|
||||||
|
else:
|
||||||
|
print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.plans=='Yes':
|
||||||
|
global_plan = (args.pov=='Third') or ((args.pov=='None') and (args.experiment in list(range(3))))
|
||||||
|
player_plan = (args.pov=='First') or ((args.pov=='None') and (args.experiment in list(range(3,9))))
|
||||||
|
elif args.plans=='No' or args.plans is None:
|
||||||
|
global_plan = False
|
||||||
|
player_plan = False
|
||||||
|
else:
|
||||||
|
print('Use Plan must be in [Yes, No], but got',args.plan)
|
||||||
|
exit()
|
||||||
|
print('global_plan', global_plan, 'player_plan', player_plan)
|
||||||
|
|
||||||
|
if args.pov=='None':
|
||||||
|
val = [GameParser(f,d_flag,0,0,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,0,0,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
val += [GameParser(f,d_flag,4,0,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,4,0,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
val = [GameParser(f,d_flag,3,0,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,3,0,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
val = [GameParser(f,d_flag,1,0,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,1,0,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
val += [GameParser(f,d_flag,2,0,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,2,0,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
model = Model(seq_model,DEVICE).to(DEVICE)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
learning_rate = 1e-4
|
||||||
|
num_epochs = 1000#2#1#
|
||||||
|
weight_decay=1e-4
|
||||||
|
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
# optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
# optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
|
||||||
|
# optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)
|
||||||
|
# optimizer = optim.Adadelta(model.parameters())
|
||||||
|
# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
|
||||||
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
# criterion = nn.MSELoss()
|
||||||
|
|
||||||
|
print(str(criterion), str(optimizer))
|
||||||
|
|
||||||
|
min_acc_loss = 100
|
||||||
|
max_f1 = 0
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
wait_epoch = 100
|
||||||
|
|
||||||
|
if args.model_path is not None:
|
||||||
|
print(f'Loading {args.model_path}')
|
||||||
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
|
acc_loss, data = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
data = list(zip(*data))
|
||||||
|
for x in data:
|
||||||
|
a, b = list(zip(*x))
|
||||||
|
f1 = f1_score(a,b,average='weighted')
|
||||||
|
f1 = f1_score(a,b,average='weighted')
|
||||||
|
if (max_f1 < f1):
|
||||||
|
max_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
print('Training model from scratch', flush=True)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
print(f'{os.getpid():6d} {epoch+1:4d},',end=' ', flush=True)
|
||||||
|
shuffle(train)
|
||||||
|
model.train()
|
||||||
|
do_split(model,train,args.experiment,criterion,optimizer=optimizer,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
model.eval()
|
||||||
|
acc_loss, data = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
|
||||||
|
data = list(zip(*data))
|
||||||
|
for x in data:
|
||||||
|
a, b = list(zip(*x))
|
||||||
|
f1 = f1_score(a,b,average='weighted')
|
||||||
|
if (max_f1 < f1):
|
||||||
|
max_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
epochs_since_improvement += 1
|
||||||
|
print()
|
||||||
|
# if (min_acc_loss > acc_loss):
|
||||||
|
# min_acc_loss = acc_loss
|
||||||
|
# epochs_since_improvement = 0
|
||||||
|
# print('^')
|
||||||
|
# else:
|
||||||
|
# epochs_since_improvement += 1
|
||||||
|
# print()
|
||||||
|
|
||||||
|
if epoch > wait_epoch and epochs_since_improvement > 20:
|
||||||
|
break
|
||||||
|
print()
|
||||||
|
print('Test')
|
||||||
|
model.load_state_dict(torch.load(args.save_path))
|
||||||
|
|
||||||
|
val = None
|
||||||
|
train = None
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,0,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,0,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,0,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,0,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
test += [GameParser(f,d_flag,2,0,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
_, data = do_split(model,test,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, device=DEVICE)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
parser.add_argument('--pov', type=str,
|
||||||
|
help='point of view [None, First, Third]')
|
||||||
|
parser.add_argument('--use_dialogue', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--use_dialogue_moves', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--plans', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--seq_model', type=str,
|
||||||
|
help='point of view [GRU, LSTM, Transformer]')
|
||||||
|
parser.add_argument('--experiment', type=int,
|
||||||
|
help='point of view [0:AggQ1, 1:AggQ2, 2:AggQ3, 3:P0Q1, 4:P0Q2, 5:P0Q3, 6:P1Q1, 7:P1Q2, 8:P1Q3]')
|
||||||
|
parser.add_argument('--seed', type=int,
|
||||||
|
help='Selet random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
|
||||||
|
'set to n\'th random number with original seed set to 0')
|
||||||
|
parser.add_argument('--save_path', type=str,
|
||||||
|
help='path where to save model')
|
||||||
|
parser.add_argument('--model_path', type=str, default=None,
|
||||||
|
help='path to the pretrained model to be loaded')
|
||||||
|
parser.add_argument('--device', type=int, default=0,
|
||||||
|
help='select cuda device number')
|
||||||
|
|
||||||
|
main(parser.parse_args())
|
104
baselines_with_dialogue_moves_graphs.sh
Normal file
104
baselines_with_dialogue_moves_graphs.sh
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
FOLDER="models/gt_dialogue_moves_bootstrap_DlgMove_Graphs"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
SEED=$2
|
||||||
|
|
||||||
|
for MODEL in LSTM; do
|
||||||
|
for DLGM in Yes; do
|
||||||
|
for EXP in 6 7 8; do
|
||||||
|
|
||||||
|
DLG="No"
|
||||||
|
POV="None"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}"
|
||||||
|
COMM="baselines_with_dialogue_moves_graphs.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="Yes"
|
||||||
|
POV="None"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}"
|
||||||
|
COMM="baselines_with_dialogue_moves_graphs.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_No_pov_None_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="No"
|
||||||
|
POV="First"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}"
|
||||||
|
COMM="baselines_with_dialogue_moves_graphs.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_No_pov_None_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="Yes"
|
||||||
|
POV="First"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}_DlgFirst"
|
||||||
|
COMM="baselines_with_dialogue_moves_graphs.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_Yes_pov_None_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
DLG="Yes"
|
||||||
|
POV="First"
|
||||||
|
|
||||||
|
FILE_NAME="gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_${DLG}_pov_${POV}_exp${EXP}_seed_${SEED}_VidFirst"
|
||||||
|
COMM="baselines_with_dialogue_moves_graphs.py"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DLGM}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --pov=${POV}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/gt_dialogue_moves_${MODEL}_dlgMove_${DLGM}_dlg_No_pov_First_exp${EXP}_seed_${SEED}.torch"
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T) $COMM" > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
|
||||||
|
echo "Done!"
|
480
compare_tom_cpa.ipynb
Normal file
480
compare_tom_cpa.ipynb
Normal file
File diff suppressed because one or more lines are too long
170
config/dataset_splits.json
Normal file
170
config/dataset_splits.json
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
{
|
||||||
|
"test": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_121618",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_170350",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_151811",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_161512",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_125416",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_165949",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_172012",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_130105",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_141752",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_152436",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_173743",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_121803",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_175421",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_133913",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_165909",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_165855",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_104621",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_162734",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_104917",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_173537",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_144127",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_130150",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_171303",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_141413",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_154431",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_150137",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_171257",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_144724",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_145152",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_142259",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_113656",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_141612",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_125550"
|
||||||
|
],
|
||||||
|
"validation": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_164019",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_173840",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_132732",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_131720",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_173627",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_153316",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_124510",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_130938",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_151438",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_173104",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_164427",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_153840",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_153941",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_172801",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_170448",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_175344",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_173413",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_104028",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_124434",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_174406",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_172923",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_174302",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_105324",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_120036",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_143947",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_122734",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_135406",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_104111",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_150003",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_155611",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_124131",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_105302",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_120221"
|
||||||
|
],
|
||||||
|
"training": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_100623",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_142313",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_122655",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_145310",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_102338",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_154347",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_171115",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_174339",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_152730",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_144034",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_172041",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_154530",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_173846",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_173317",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_194755",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_175855",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_173658",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_195348",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_124913",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_103805",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_192714",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_170855",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_102403",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_195827",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_175902",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_171045",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_171817",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_200825",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_172132",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_144657",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_170427",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_125317",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_103607",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_163839",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_163559",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_143013",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_153510",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_165333",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_143643",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_102828",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_171303",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_143419",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_150209",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_120302",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_154441",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_171107",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_163106",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_123714",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_172843",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_150329",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_105413",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_125749",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_104406",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_165255",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_164645",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_104220",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_170235",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_104417",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_171600",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_131000",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_111052",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_181018",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_114959",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_151402",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_114002",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_141329",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_153758",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_122939",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_103433",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_124701",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_105754",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_122435",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_102659",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_151443",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_170042",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_121258",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_150820",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_153152",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_154907",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_123642",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_104749",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_125051",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_152929",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_124204",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_141258",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_112213",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_140239",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_143329",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_163720",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_142447",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_114330",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_161135",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_155738",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_153821",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_123550",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_120630"
|
||||||
|
]
|
||||||
|
}
|
14
config/dataset_splits_dev.json
Normal file
14
config/dataset_splits_dev.json
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
{
|
||||||
|
"test": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_121618",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_170350"
|
||||||
|
],
|
||||||
|
"validation": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_164019",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_173840"
|
||||||
|
],
|
||||||
|
"training": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_100623",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_142313"
|
||||||
|
]
|
||||||
|
}
|
170
config/dataset_splits_new.json
Normal file
170
config/dataset_splits_new.json
Normal file
|
@ -0,0 +1,170 @@
|
||||||
|
{
|
||||||
|
"test": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_121618",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_170350",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_151811",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_161512",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_125416",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_165949",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_172012",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_130105",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_141752",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_152436",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_173743",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_121803",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_175421",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_133913",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_165909",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_165855",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_104621",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_162734",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_104917",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_173537",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_144127",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_130150",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_171303",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_141413",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_154431",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_150137",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_171257",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_144724",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_145152",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_142259",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_113656",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_141612",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_125550"
|
||||||
|
],
|
||||||
|
"validation": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_164019",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_173840",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_132732",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_131720",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_173627",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_153316",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_124510",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_130938",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_151438",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_173104",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_164427",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_153840",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_153941",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_172801",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_170448",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_175344",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_173413",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_104028",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_124434",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_174406",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_172923",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_174302",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_105324",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_120036",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_143947",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_122734",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_135406",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_104111",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_150003",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_155611",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_124131",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_105302",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_120221"
|
||||||
|
],
|
||||||
|
"training": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_100623",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_142313",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_122655",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_145310",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_102338",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_154347",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_171115",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_174339",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_152730",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_144034",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_172041",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_154530",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_173846",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_173317",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_194755",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_175855",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_173658",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_195348",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_124913",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_103805",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_192714",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_170855",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_102403",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_195827",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_175902",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_171045",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_171817",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_200825",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_172132",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_144657",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_170427",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_125317",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_103607",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_163839",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_163559",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_143013",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_153510",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_165333",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_143643",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_102828",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_171303",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_143419",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_150209",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_120302",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_154441",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_171107",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_163106",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_123714",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_172843",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_150329",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_105413",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_125749",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_104406",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_165255",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_164645",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_104220",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_170235",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_104417",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_171600",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_131000",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_111052",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_181018",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_114959",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_151402",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_114002",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_141329",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_153758",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_122939",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_103433",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_124701",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_105754",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_122435",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_102659",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_151443",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_170042",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_121258",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_150820",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_153152",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_154907",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_123642",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_104749",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_125051",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_152929",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_124204",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_141258",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_112213",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_140239",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_143329",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_163720",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220411_142447",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_114330",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220407_161135",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220412_155738",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220408_153821",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220414_123550",
|
||||||
|
"XXX/new_logs/141_212_110_53_20220415_120630"
|
||||||
|
]
|
||||||
|
}
|
110
config/dataset_splits_old.json
Normal file
110
config/dataset_splits_old.json
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
{
|
||||||
|
"test": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_121618",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_170350",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_151811",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_161512",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_125416",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_165949",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_172012",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_130105",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_141752",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_152436",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_173743",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_121803",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_175421",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_133913",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_165909",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_165855",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_104621",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_162734",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_104917",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_173537",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_144127"
|
||||||
|
],
|
||||||
|
"validation": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_164019",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_173840",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_132732",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_131720",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_173627",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_153316",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_124510",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_130938",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_151438",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_173104",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_164427",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_153840",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_153941",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_172801",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_170448",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_175344",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_173413",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_104028",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_124434",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_174406",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_172923"
|
||||||
|
],
|
||||||
|
"training": [
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_100623",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_142313",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_122655",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_145310",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_102338",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_154347",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_171115",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_174339",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_152730",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_144034",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_172041",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_154530",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_173846",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_173317",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_194755",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_175855",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_173658",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_195348",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_124913",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_103805",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_192714",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_170855",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_102403",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_195827",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_175902",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_171045",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_171817",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210413_200825",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_172132",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_144657",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_170427",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_125317",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_103607",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_163839",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_163559",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_143013",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_153510",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_165333",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210322_143643",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_102828",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_171303",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_143419",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210429_150209",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_120302",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210405_154441",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_171107",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210408_163106",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_123714",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210401_172843",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210423_150329",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_105413",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210325_125749",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210409_104406",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210430_165255",
|
||||||
|
"XXX/main_logs/172_31_25_15_20210428_164645",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210325_104220",
|
||||||
|
"XXX/main_logs/141_212_108_99_20210407_170235",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210318_104417",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210323_171600",
|
||||||
|
"XXX/saved_logs/141_212_108_99_20210317_131000"
|
||||||
|
]
|
||||||
|
}
|
21
config/dialogue_act_label_names.json
Normal file
21
config/dialogue_act_label_names.json
Normal file
|
@ -0,0 +1,21 @@
|
||||||
|
{
|
||||||
|
"ACKNOWLEDGMENT": 0,
|
||||||
|
"AGREEMENT": 1,
|
||||||
|
"APOLOGY": 2,
|
||||||
|
"AnsAff": 3,
|
||||||
|
"AnsNeg": 4,
|
||||||
|
"AnsOth": 5,
|
||||||
|
"AnsOther": 6,
|
||||||
|
"BACKCHANNEL": 7,
|
||||||
|
"CLOSING": 8,
|
||||||
|
"DIRECTIVE": 9,
|
||||||
|
"DeclarativeQuestion": 10,
|
||||||
|
"GameSpec": 11,
|
||||||
|
"OPENING": 12,
|
||||||
|
"OPINION": 13,
|
||||||
|
"OrClause": 14,
|
||||||
|
"STATEMENT": 15,
|
||||||
|
"WhQuestion": 16,
|
||||||
|
"YesNoQuestion": 17,
|
||||||
|
"other": 18
|
||||||
|
}
|
3629
config/dialogue_act_labels.json
Normal file
3629
config/dialogue_act_labels.json
Normal file
File diff suppressed because it is too large
Load diff
37
config/dialogue_move_label_names.json
Normal file
37
config/dialogue_move_label_names.json
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
{
|
||||||
|
"ACKNOWLEDGEMENT": 0,
|
||||||
|
"AGREEMENT": 1,
|
||||||
|
"APOLOGY": 2,
|
||||||
|
"AnsAff": 3,
|
||||||
|
"AnsNeg": 4,
|
||||||
|
"AnsOth": 5,
|
||||||
|
"BACKCHANNEL": 6,
|
||||||
|
"CLOSING": 7,
|
||||||
|
"Directive-Make": 8,
|
||||||
|
"Directive-Other": 9,
|
||||||
|
"Directive-PickUp": 10,
|
||||||
|
"Directive-PutDown": 11,
|
||||||
|
"Directive-PutOn": 12,
|
||||||
|
"GameSpec": 13,
|
||||||
|
"Inquiry-Act": 14,
|
||||||
|
"Inquiry-Goal": 15,
|
||||||
|
"Inquiry-NextStep": 16,
|
||||||
|
"Inquiry-OwnAct": 17,
|
||||||
|
"Inquiry-Possession": 18,
|
||||||
|
"Inquiry-Recipe": 19,
|
||||||
|
"Inquiry-Requirement": 20,
|
||||||
|
"OPENING": 21,
|
||||||
|
"OPINION": 22,
|
||||||
|
"OrClause": 23,
|
||||||
|
"Statement-Goal": 24,
|
||||||
|
"Statement-Inability": 25,
|
||||||
|
"Statement-LackKnowledge": 26,
|
||||||
|
"Statement-NextStep": 27,
|
||||||
|
"Statement-Other": 28,
|
||||||
|
"Statement-OwnAct": 29,
|
||||||
|
"Statement-Possession": 30,
|
||||||
|
"Statement-Recipe": 31,
|
||||||
|
"Statement-Requirement": 32,
|
||||||
|
"Statement-StepDone": 33,
|
||||||
|
"other": 34
|
||||||
|
}
|
3629
config/dialogue_move_labels.json
Normal file
3629
config/dialogue_move_labels.json
Normal file
File diff suppressed because it is too large
Load diff
23
config/materials.json
Normal file
23
config/materials.json
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
[
|
||||||
|
"BLACK_WOOL",
|
||||||
|
"BLUE_WOOL",
|
||||||
|
"BROWN_WOOL",
|
||||||
|
"COBBLESTONE",
|
||||||
|
"CYAN_WOOL",
|
||||||
|
"DIAMOND_BLOCK",
|
||||||
|
"EMERALD_BLOCK",
|
||||||
|
"GOLD_BLOCK",
|
||||||
|
"GRAY_WOOL",
|
||||||
|
"GREEN_WOOL",
|
||||||
|
"IRON_BLOCK",
|
||||||
|
"LAPIS_BLOCK",
|
||||||
|
"LIME_WOOL",
|
||||||
|
"MAGENTA_WOOL",
|
||||||
|
"OBSIDIAN",
|
||||||
|
"ORANGE_WOOL",
|
||||||
|
"REDSTONE_BLOCK",
|
||||||
|
"RED_WOOL",
|
||||||
|
"SOUL_SAND",
|
||||||
|
"WHITE_WOOL",
|
||||||
|
"YELLOW_WOOL"
|
||||||
|
]
|
8
config/mines.json
Normal file
8
config/mines.json
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
[
|
||||||
|
"ACACIA_PLANKS",
|
||||||
|
"BIRCH_PLANKS",
|
||||||
|
"DARK_OAK_PLANKS",
|
||||||
|
"JUNGLE_PLANKS",
|
||||||
|
"OAK_PLANKS",
|
||||||
|
"SPRUCE_PLANKS"
|
||||||
|
]
|
14
config/tools.json
Normal file
14
config/tools.json
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
[
|
||||||
|
"DIAMOND_AXE",
|
||||||
|
"DIAMOND_HOE",
|
||||||
|
"DIAMOND_PICKAXE",
|
||||||
|
"DIAMOND_SHOVEL",
|
||||||
|
"GOLDEN_AXE",
|
||||||
|
"GOLDEN_HOE",
|
||||||
|
"GOLDEN_PICKAXE",
|
||||||
|
"GOLDEN_SHOVEL",
|
||||||
|
"IRON_AXE",
|
||||||
|
"IRON_HOE",
|
||||||
|
"IRON_PICKAXE",
|
||||||
|
"IRON_SHOVEL"
|
||||||
|
]
|
75
intermediate_representations.py
Normal file
75
intermediate_representations.py
Normal file
|
@ -0,0 +1,75 @@
|
||||||
|
from src.models.model_with_dialogue_moves import Model as ToMModel
|
||||||
|
from src.models.plan_model_graphs import Model as CPAModel
|
||||||
|
from src.data.game_parser import GameParser
|
||||||
|
from src.data.game_parser_graphs_new import GameParser as GameParserCPA
|
||||||
|
import torch
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
MODEL_TYPES = {
|
||||||
|
'GRU' : 0,
|
||||||
|
'LSTM' : 1,
|
||||||
|
'Transformer' : 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
def main():
|
||||||
|
|
||||||
|
model_file = "models/tom_lstm_baseline/tom6_model.torch"
|
||||||
|
use_dialogue = "Yes"
|
||||||
|
model_type_name = "LSTM"
|
||||||
|
model_type = MODEL_TYPES[model_type_name]
|
||||||
|
model = ToMModel(model_type).to(DEVICE)
|
||||||
|
model.load_state_dict(torch.load(model_file))
|
||||||
|
dataset_splits = json.load(open('config/dataset_splits.json'))
|
||||||
|
for set in dataset_splits.values():
|
||||||
|
for path in set:
|
||||||
|
for pov in [1, 2]:
|
||||||
|
out_file = f'{path}/intermediate_ToM6_{path.split("/")[-1]}_player{pov}.npz'
|
||||||
|
# if os.path.isfile(out_file):
|
||||||
|
# continue
|
||||||
|
game = GameParser(path,use_dialogue=='Yes',pov,0,True)
|
||||||
|
l = model(game, global_plan=False, player_plan=True,intermediate=True).cpu().data.numpy()
|
||||||
|
np.savez_compressed(open(out_file,'wb'), data=l)
|
||||||
|
print(out_file,l.shape,model_type_name,use_dialogue,use_dialogue=='Yes')
|
||||||
|
|
||||||
|
model_file = "models/tom_lstm_baseline/tom7_model.torch"
|
||||||
|
use_dialogue = "Yes"
|
||||||
|
model_type_name = 'LSTM'
|
||||||
|
model_type = MODEL_TYPES[model_type_name]
|
||||||
|
model = ToMModel(model_type).to(DEVICE)
|
||||||
|
model.load_state_dict(torch.load(model_file))
|
||||||
|
dataset_splits = json.load(open('config/dataset_splits.json'))
|
||||||
|
for set in dataset_splits.values():
|
||||||
|
for path in set:
|
||||||
|
for pov in [1, 2]:
|
||||||
|
out_file = f'{path}/intermediate_ToM7_{path.split("/")[-1]}_player{pov}.npz'
|
||||||
|
# if os.path.isfile(out_file):
|
||||||
|
# continue
|
||||||
|
game = GameParser(path,use_dialogue=='Yes',4,0,True)
|
||||||
|
l = model(game, global_plan=False, player_plan=True,intermediate=True).cpu().data.numpy()
|
||||||
|
np.savez_compressed(open(out_file,'wb'), data=l)
|
||||||
|
print(out_file,l.shape,model_type_name,use_dialogue,use_dialogue=='Yes')
|
||||||
|
|
||||||
|
model_file = "models/tom_lstm_baseline/tom8_model.torch"
|
||||||
|
use_dialogue = "Yes"
|
||||||
|
model_type_name = 'LSTM'
|
||||||
|
model_type = MODEL_TYPES[model_type_name]
|
||||||
|
model = ToMModel(model_type).to(DEVICE)
|
||||||
|
model.load_state_dict(torch.load(model_file))
|
||||||
|
dataset_splits = json.load(open('config/dataset_splits.json'))
|
||||||
|
for set in dataset_splits.values():
|
||||||
|
for path in set:
|
||||||
|
for pov in [1, 2]:
|
||||||
|
out_file = f'{path}/intermediate_ToM8_{path.split("/")[-1]}_player{pov}.npz'
|
||||||
|
# if os.path.isfile(out_file):
|
||||||
|
# continue
|
||||||
|
game = GameParser(path,use_dialogue=='Yes',4,True)
|
||||||
|
l = model(game, global_plan=False, player_plan=True,intermediate=True).cpu().data.numpy()
|
||||||
|
np.savez_compressed(open(out_file,'wb'), data=l)
|
||||||
|
print(out_file,l.shape,model_type_name,use_dialogue,use_dialogue=='Yes')
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
main()
|
383
logistic_regression_tom_feats.py
Normal file
383
logistic_regression_tom_feats.py
Normal file
|
@ -0,0 +1,383 @@
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
import pickle
|
||||||
|
import os
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.ensemble import RandomForestClassifier
|
||||||
|
from sklearn.metrics import accuracy_score, classification_report, f1_score
|
||||||
|
from src.data.game_parser_graphs_new import GameParser, make_splits, onehot, set_seed
|
||||||
|
from tqdm import tqdm
|
||||||
|
from scipy.stats import wilcoxon
|
||||||
|
from sklearn.exceptions import ConvergenceWarning
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from sklearn.manifold import TSNE
|
||||||
|
from sklearn.preprocessing import StandardScaler
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
import umap
|
||||||
|
import warnings
|
||||||
|
warnings.simplefilter("ignore", category=ConvergenceWarning)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_q(q, game):
|
||||||
|
if not q is None:
|
||||||
|
q ,l = q
|
||||||
|
q = np.concatenate([
|
||||||
|
onehot(q[2],2),
|
||||||
|
onehot(q[3],2),
|
||||||
|
onehot(q[4][0][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[4][1][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[4][2]+1,2),
|
||||||
|
onehot(q[5][0][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[5][1][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[5][2]+1,2)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
q = np.zeros(100)
|
||||||
|
l = None
|
||||||
|
return q, l
|
||||||
|
|
||||||
|
def cosine_similarity(array1, array2):
|
||||||
|
"""
|
||||||
|
Compute the cosine similarity between two arrays.
|
||||||
|
Parameters:
|
||||||
|
- array1: First input array
|
||||||
|
- array2: Second input array
|
||||||
|
Returns:
|
||||||
|
- similarity: Cosine similarity between the two arrays
|
||||||
|
"""
|
||||||
|
dot_product = np.dot(array1, array2)
|
||||||
|
norm_array1 = np.linalg.norm(array1)
|
||||||
|
norm_array2 = np.linalg.norm(array2)
|
||||||
|
similarity = dot_product / (norm_array1 * norm_array2)
|
||||||
|
return similarity
|
||||||
|
|
||||||
|
def compute_and_plot_pca(data1, data2, labels=None, fname='pca'):
|
||||||
|
"""
|
||||||
|
Compute and plot Principal Component Analysis (PCA) for a given dataset with 2 components.
|
||||||
|
Parameters:
|
||||||
|
- data: Input dataset
|
||||||
|
- labels: Labels for data points (optional)
|
||||||
|
Returns:
|
||||||
|
- pca_result: Result of PCA transformation
|
||||||
|
"""
|
||||||
|
scaler1 = StandardScaler()
|
||||||
|
data_standardized1 = scaler1.fit_transform(data1)
|
||||||
|
scaler2 = StandardScaler()
|
||||||
|
data_standardized2 = scaler2.fit_transform(data2)
|
||||||
|
pca1 = PCA(n_components=2)
|
||||||
|
pca_result1 = pca1.fit_transform(data_standardized1)
|
||||||
|
pca2 = PCA(n_components=2)
|
||||||
|
pca_result2 = pca2.fit_transform(data_standardized2)
|
||||||
|
pca_result = np.concatenate([pca_result1, pca_result2])
|
||||||
|
unique_labels = np.unique(labels) if labels is not None else [None]
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
for unique_label in unique_labels:
|
||||||
|
mask = (labels == unique_label) if labels is not None else slice(None)
|
||||||
|
plt.scatter(pca_result[mask, 0], pca_result[mask, 1], label=unique_label)
|
||||||
|
plt.xlabel('Principal Component 1')
|
||||||
|
plt.ylabel('Principal Component 2')
|
||||||
|
if labels is not None:
|
||||||
|
plt.legend()
|
||||||
|
os.makedirs("figures/", exist_ok=True)
|
||||||
|
plt.savefig(f"figures/{fname}.pdf", bbox_inches='tight')
|
||||||
|
return pca_result
|
||||||
|
|
||||||
|
def compute_and_plot_tsne(data1, data2, labels=None, fname='tsne'):
|
||||||
|
"""
|
||||||
|
Compute and plot t-SNE for a given standardized dataset with 2 components.
|
||||||
|
Parameters:
|
||||||
|
- data: Input dataset
|
||||||
|
- labels: Labels for data points (optional)
|
||||||
|
Returns:
|
||||||
|
- tsne_result: Result of t-SNE transformation
|
||||||
|
"""
|
||||||
|
scaler1 = StandardScaler()
|
||||||
|
data_standardized1 = scaler1.fit_transform(data1)
|
||||||
|
tsne1 = TSNE(n_components=2)
|
||||||
|
tsne_result1 = tsne1.fit_transform(data_standardized1)
|
||||||
|
scaler2 = StandardScaler()
|
||||||
|
data_standardized2 = scaler2.fit_transform(data2)
|
||||||
|
tsne2 = TSNE(n_components=2)
|
||||||
|
tsne_result2 = tsne2.fit_transform(data_standardized2)
|
||||||
|
tsne_result = np.concatenate([tsne_result1, tsne_result2])
|
||||||
|
unique_labels = np.unique(labels) if labels is not None else [None]
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
for unique_label in unique_labels:
|
||||||
|
mask = (labels == unique_label) if labels is not None else slice(None)
|
||||||
|
plt.scatter(tsne_result[mask, 0], tsne_result[mask, 1], label=unique_label)
|
||||||
|
plt.xlabel('t-SNE Component 1')
|
||||||
|
plt.ylabel('t-SNE Component 2')
|
||||||
|
if labels is not None:
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(f"figures/{fname}.pdf", bbox_inches='tight')
|
||||||
|
return tsne_result
|
||||||
|
|
||||||
|
def compute_and_plot_umap(data1, data2, labels=None, fname='umap'):
|
||||||
|
"""
|
||||||
|
Compute and plot UMAP for a given standardized dataset with 2 components.
|
||||||
|
Parameters:
|
||||||
|
- data: Input dataset
|
||||||
|
- labels: Labels for data points (optional)
|
||||||
|
Returns:
|
||||||
|
- umap_result: Result of UMAP transformation
|
||||||
|
"""
|
||||||
|
scaler1 = StandardScaler()
|
||||||
|
data_standardized1 = scaler1.fit_transform(data1)
|
||||||
|
umap_model1 = umap.UMAP(n_components=2)
|
||||||
|
umap_result1 = umap_model1.fit_transform(data_standardized1)
|
||||||
|
scaler2 = StandardScaler()
|
||||||
|
data_standardized2 = scaler2.fit_transform(data2)
|
||||||
|
umap_model2 = umap.UMAP(n_components=2)
|
||||||
|
umap_result2 = umap_model2.fit_transform(data_standardized2)
|
||||||
|
umap_result = np.concatenate([umap_result1, umap_result2])
|
||||||
|
unique_labels = np.unique(labels) if labels is not None else [None]
|
||||||
|
plt.figure(figsize=(8, 6))
|
||||||
|
for unique_label in unique_labels:
|
||||||
|
mask = (labels == unique_label) if labels is not None else slice(None)
|
||||||
|
plt.scatter(umap_result[mask, 0], umap_result[mask, 1], label=unique_label)
|
||||||
|
plt.xlabel('UMAP Component 1')
|
||||||
|
plt.ylabel('UMAP Component 2')
|
||||||
|
if labels is not None:
|
||||||
|
plt.legend()
|
||||||
|
plt.savefig(f"figures/{fname}.pdf", bbox_inches='tight')
|
||||||
|
return umap_result
|
||||||
|
|
||||||
|
def prepare_data_tom(mode):
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
d_flag = False
|
||||||
|
d_move_flag = False
|
||||||
|
if mode == 'train':
|
||||||
|
data = [GameParser(f,d_flag,1,7,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
data += [GameParser(f,d_flag,2,7,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif mode == 'test':
|
||||||
|
data = [GameParser(f,d_flag,1,7,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
data += [GameParser(f,d_flag,2,7,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
raise ValueError('train or test are supported')
|
||||||
|
tom6repr = []
|
||||||
|
tom7repr = []
|
||||||
|
tom8repr = []
|
||||||
|
tom6labels = []
|
||||||
|
tom7labels = []
|
||||||
|
tom8labels = []
|
||||||
|
for game in data:
|
||||||
|
_, _, _, q, _, _, interm, _ = zip(*list(game))
|
||||||
|
interm = np.array(interm)
|
||||||
|
# intermediate = np.concatenate([ToM6,ToM7,ToM8,DAct,DMove])
|
||||||
|
tom6, tom7, tom8, _, _ = np.split(interm, np.cumsum([1024] * 5)[:-1], axis=1)
|
||||||
|
q = [parse_q(x, game) for x in q]
|
||||||
|
q, l = zip(*q)
|
||||||
|
indexes = [idx for idx, element in enumerate(l) if element is not None]
|
||||||
|
tom6repr.append(tom6[indexes])
|
||||||
|
tom7repr.append(tom7[indexes])
|
||||||
|
tom8repr.append(tom8[indexes])
|
||||||
|
l = [item[1] for item in l if item is not None]
|
||||||
|
tom6labels.append([['NO', 'MAYBE', 'YES'].index(item[0]) for item in l])
|
||||||
|
tom7labels.append([['NO', 'MAYBE', 'YES'].index(item[1]) for item in l])
|
||||||
|
tom8labels.append([game.materials_dict[item[2]] if item[2] in game.materials_dict else 0 for item in l])
|
||||||
|
tom6labels = sum(tom6labels, [])
|
||||||
|
tom7labels = sum(tom7labels, [])
|
||||||
|
tom8labels = sum(tom8labels, [])
|
||||||
|
return np.concatenate(tom6repr), tom6labels, np.concatenate(tom7repr), tom7labels, np.concatenate(tom8repr), tom8labels
|
||||||
|
|
||||||
|
def prepare_data_cpa(mode, experiment):
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
d_flag = False
|
||||||
|
d_move_flag = False
|
||||||
|
if mode == "train":
|
||||||
|
data = [GameParser(f,d_flag,1,7,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
data += [GameParser(f,d_flag,2,7,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if experiment == 2:
|
||||||
|
with open('XXX', 'rb') as f:
|
||||||
|
feats = pickle.load(f)
|
||||||
|
elif experiment == 3:
|
||||||
|
with open('XXX', 'rb') as f:
|
||||||
|
feats = pickle.load(f)
|
||||||
|
else: raise ValueError
|
||||||
|
elif mode == "test":
|
||||||
|
data = [GameParser(f,d_flag,1,7,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
data += [GameParser(f,d_flag,2,7,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if experiment == 2:
|
||||||
|
with open('XXX', 'rb') as f:
|
||||||
|
feats = pickle.load(f)
|
||||||
|
elif experiment == 3:
|
||||||
|
with open('XXX', 'rb') as f:
|
||||||
|
feats = pickle.load(f)
|
||||||
|
else: raise ValueError
|
||||||
|
else:
|
||||||
|
raise ValueError('train or test are supported')
|
||||||
|
tom6labels = []
|
||||||
|
tom7labels = []
|
||||||
|
tom8labels = []
|
||||||
|
features = [item[0] for item in feats]
|
||||||
|
game_names = [item[1] for item in feats]
|
||||||
|
selected_feats = []
|
||||||
|
for i, game in enumerate(data):
|
||||||
|
_, _, _, q, _, _, _, _ = zip(*list(game))
|
||||||
|
q = [parse_q(x, game) for x in q]
|
||||||
|
q, l = zip(*q)
|
||||||
|
indexes = [idx for idx, element in enumerate(l) if element is not None]
|
||||||
|
assert game.game_path.split("/")[-1] == game_names[i].split("/")[-1]
|
||||||
|
selected_feats.append(features[i][indexes])
|
||||||
|
l = [item[1] for item in l if item is not None]
|
||||||
|
tom6labels.append([['NO', 'MAYBE', 'YES'].index(item[0]) for item in l])
|
||||||
|
tom7labels.append([['NO', 'MAYBE', 'YES'].index(item[1]) for item in l])
|
||||||
|
tom8labels.append([game.materials_dict[item[2]] if item[2] in game.materials_dict else 0 for item in l])
|
||||||
|
tom6labels = sum(tom6labels, [])
|
||||||
|
tom7labels = sum(tom7labels, [])
|
||||||
|
tom8labels = sum(tom8labels, [])
|
||||||
|
selected_feats = np.concatenate(selected_feats)
|
||||||
|
return selected_feats, tom6labels, tom7labels, tom8labels
|
||||||
|
|
||||||
|
def fit_and_test_LR(X_train, y_train, X_test, y_test, max_iter=100):
|
||||||
|
logreg_model = LogisticRegression(max_iter=max_iter)
|
||||||
|
logreg_model.fit(X_train, y_train)
|
||||||
|
y_pred = logreg_model.predict(X_test)
|
||||||
|
f1 = f1_score(y_test, y_pred, average="weighted", zero_division=1)
|
||||||
|
# classification_report_output = classification_report(y_test, y_pred)
|
||||||
|
print("F1 score:", f1)
|
||||||
|
# print("Classification Report:\n", classification_report_output)
|
||||||
|
return logreg_model
|
||||||
|
|
||||||
|
def fit_and_test_RF(X_train, y_train, X_test, y_test, n_estimators):
|
||||||
|
model = RandomForestClassifier(n_estimators=n_estimators)
|
||||||
|
model.fit(X_train, y_train)
|
||||||
|
y_pred = model.predict(X_test)
|
||||||
|
f1 = f1_score(y_test, y_pred, average="weighted", zero_division=1)
|
||||||
|
print("F1 score:", f1)
|
||||||
|
return model
|
||||||
|
|
||||||
|
def wilcoxon_test(model1, model2, X_test_1, X_test_2):
|
||||||
|
probabilities_model1 = model1.predict_proba(X_test_1)[:, 1]
|
||||||
|
probabilities_model2 = model2.predict_proba(X_test_2)[:, 1]
|
||||||
|
differences = probabilities_model1 - probabilities_model2
|
||||||
|
_, p_value_wilcoxon = wilcoxon(differences)
|
||||||
|
print("Wilcoxon signed-rank test p-value:", p_value_wilcoxon)
|
||||||
|
return p_value_wilcoxon
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("--task", type=str)
|
||||||
|
parser.add_argument("--seed", type=int)
|
||||||
|
parser.add_argument("--experiment", type=int)
|
||||||
|
args = parser.parse_args()
|
||||||
|
set_seed(args.seed)
|
||||||
|
task = args.task
|
||||||
|
experiment = args.experiment
|
||||||
|
|
||||||
|
if task == "tom":
|
||||||
|
tom6_train_x, tom6_train_labels, tom7_train_x, tom7_train_labels, tom8_train_x, tom8_train_labels = prepare_data_tom("train")
|
||||||
|
tom6_test_x, tom6_test_labels, tom7_test_x, tom7_test_labels, tom8_test_x, tom8_test_labels = prepare_data_tom("test")
|
||||||
|
print("=========== EXP 6 ========================================")
|
||||||
|
# 0.6056079527261083
|
||||||
|
fit_and_test_LR(tom6_train_x, tom6_train_labels, tom6_test_x, tom6_test_labels, 100)
|
||||||
|
print("=========== EXP 7 ========================================")
|
||||||
|
# 0.5090737845776365
|
||||||
|
fit_and_test_LR(tom7_train_x, tom7_train_labels, tom7_test_x, tom7_test_labels, 100)
|
||||||
|
print("=========== EXP 8 ========================================")
|
||||||
|
# 0.10206891928130866
|
||||||
|
fit_and_test_LR(tom8_train_x, tom8_train_labels, tom8_test_x, tom8_test_labels, 6)
|
||||||
|
breakpoint()
|
||||||
|
|
||||||
|
elif task == "cpa":
|
||||||
|
train_x, tom6_train_labels, tom7_train_labels, tom8_train_labels = prepare_data_cpa("train", experiment)
|
||||||
|
test_x, tom6_test_labels, tom7_test_labels, tom8_test_labels = prepare_data_cpa("test", experiment)
|
||||||
|
print("=========== EXP 6 ========================================")
|
||||||
|
# 0.5157497361676466 139
|
||||||
|
fit_and_test_LR(train_x, tom6_train_labels, test_x, tom6_test_labels, 139 if experiment == 2 else 11)
|
||||||
|
print("=========== EXP 7 ========================================")
|
||||||
|
# 0.49755418256915795 307
|
||||||
|
fit_and_test_LR(train_x, tom7_train_labels, test_x, tom7_test_labels, 307 if experiment == 2 else 25)
|
||||||
|
print("=========== EXP 8 ========================================")
|
||||||
|
# 0.14099639490943838 23
|
||||||
|
fit_and_test_LR(train_x, tom8_train_labels, test_x, tom8_test_labels, 23 if experiment == 2 else 9)
|
||||||
|
breakpoint()
|
||||||
|
|
||||||
|
elif task == "random":
|
||||||
|
tom6_train_x, tom6_train_labels, tom7_train_x, tom7_train_labels, tom8_train_x, tom8_train_labels = prepare_data_tom("train")
|
||||||
|
tom6_test_x, tom6_test_labels, tom7_test_x, tom7_test_labels, tom8_test_x, tom8_test_labels = prepare_data_tom("test")
|
||||||
|
tom6_train_x = np.random.randn(*tom6_train_x.shape) * 0.1
|
||||||
|
tom7_train_x = np.random.randn(*tom7_train_x.shape) * 0.1
|
||||||
|
tom8_train_x = np.random.randn(*tom8_train_x.shape) * 0.1
|
||||||
|
tom6_test_x = np.random.randn(*tom6_test_x.shape) * 0.1
|
||||||
|
tom7_test_x = np.random.randn(*tom7_test_x.shape) * 0.1
|
||||||
|
tom8_test_x = np.random.randn(*tom8_test_x.shape) * 0.1
|
||||||
|
print("=========== EXP 6 ========================================")
|
||||||
|
# 0.4573518645097593
|
||||||
|
fit_and_test_LR(tom6_train_x, tom6_train_labels, tom6_test_x, tom6_test_labels, 100)
|
||||||
|
print("=========== EXP 7 ========================================")
|
||||||
|
# 0.45066310491597705
|
||||||
|
fit_and_test_LR(tom7_train_x, tom7_train_labels, tom7_test_x, tom7_test_labels, 100)
|
||||||
|
print("=========== EXP 8 ========================================")
|
||||||
|
# 0.09281225255303022
|
||||||
|
fit_and_test_LR(tom8_train_x, tom8_train_labels, tom8_test_x, tom8_test_labels, 100)
|
||||||
|
breakpoint()
|
||||||
|
|
||||||
|
elif task == "all":
|
||||||
|
############## TOM
|
||||||
|
print("############## TOM")
|
||||||
|
tom6_train_x_tom, tom6_train_labels_tom, tom7_train_x_tom, tom7_train_labels_tom, tom8_train_x_tom, tom8_train_labels_tom = prepare_data_tom("train")
|
||||||
|
tom6_test_x_tom, tom6_test_labels_tom, tom7_test_x_tom, tom7_test_labels_tom, tom8_test_x_tom, tom8_test_labels_tom = prepare_data_tom("test")
|
||||||
|
print("=========== EXP 6 ========================================")
|
||||||
|
model_tom_6 = fit_and_test_LR(tom6_train_x_tom, tom6_train_labels_tom, tom6_test_x_tom, tom6_test_labels_tom, 100)
|
||||||
|
print("=========== EXP 7 ========================================")
|
||||||
|
model_tom_7 = fit_and_test_LR(tom7_train_x_tom, tom7_train_labels_tom, tom7_test_x_tom, tom7_test_labels_tom, 100)
|
||||||
|
print("=========== EXP 8 ========================================")
|
||||||
|
model_tom_8 = fit_and_test_LR(tom8_train_x_tom, tom8_train_labels_tom, tom8_test_x_tom, tom8_test_labels_tom, 6)
|
||||||
|
############## CPA
|
||||||
|
print("############## CPA")
|
||||||
|
train_x_cpa, tom6_train_labels_cpa, tom7_train_labels_cpa, tom8_train_labels_cpa = prepare_data_cpa("train", experiment)
|
||||||
|
test_x_cpa, tom6_test_labels_cpa, tom7_test_labels_cpa, tom8_test_labels_cpa = prepare_data_cpa("test", experiment)
|
||||||
|
print("=========== EXP 6 ========================================")
|
||||||
|
model_cpa_6 = fit_and_test_LR(train_x_cpa, tom6_train_labels_cpa, test_x_cpa, tom6_test_labels_cpa, 139)
|
||||||
|
print("=========== EXP 7 ========================================")
|
||||||
|
model_cpa_7 = fit_and_test_LR(train_x_cpa, tom7_train_labels_cpa, test_x_cpa, tom7_test_labels_cpa, 307)
|
||||||
|
print("=========== EXP 8 ========================================")
|
||||||
|
model_cpa_8 = fit_and_test_LR(train_x_cpa, tom8_train_labels_cpa, test_x_cpa, tom8_test_labels_cpa, 23)
|
||||||
|
############## RANDOM
|
||||||
|
print("############## RANDOM")
|
||||||
|
tom6_train_x_rand = np.random.randn(*tom6_train_x_tom.shape) * 0.1
|
||||||
|
tom7_train_x_rand = np.random.randn(*tom7_train_x_tom.shape) * 0.1
|
||||||
|
tom8_train_x_rand = np.random.randn(*tom8_train_x_tom.shape) * 0.1
|
||||||
|
tom6_test_x_rand = np.random.randn(*tom6_test_x_tom.shape) * 0.1
|
||||||
|
tom7_test_x_rand = np.random.randn(*tom7_test_x_tom.shape) * 0.1
|
||||||
|
tom8_test_x_rand = np.random.randn(*tom8_test_x_tom.shape) * 0.1
|
||||||
|
print("=========== EXP 6 ========================================")
|
||||||
|
model_rand_6 = fit_and_test_LR(tom6_train_x_rand, tom6_train_labels_tom, tom6_test_x_rand, tom6_test_labels_tom, 100)
|
||||||
|
print("=========== EXP 7 ========================================")
|
||||||
|
model_rand_7 = fit_and_test_LR(tom7_train_x_rand, tom7_train_labels_tom, tom7_test_x_rand, tom7_test_labels_tom, 100)
|
||||||
|
print("=========== EXP 8 ========================================")
|
||||||
|
model_rand_8 = fit_and_test_LR(tom8_train_x_rand, tom8_train_labels_tom, tom8_test_x_rand, tom8_test_labels_tom, 100)
|
||||||
|
wilcoxon_test(model_tom_6, model_cpa_6, tom6_test_x_tom, test_x_cpa)
|
||||||
|
wilcoxon_test(model_rand_6, model_cpa_6, tom6_test_x_rand, test_x_cpa)
|
||||||
|
wilcoxon_test(model_rand_6, model_tom_6, tom6_test_x_rand, tom6_test_x_tom)
|
||||||
|
wilcoxon_test(model_tom_7, model_cpa_7, tom7_test_x_tom, test_x_cpa)
|
||||||
|
wilcoxon_test(model_rand_7, model_cpa_7, tom7_test_x_rand, test_x_cpa)
|
||||||
|
wilcoxon_test(model_rand_7, model_tom_7, tom7_test_x_rand, tom7_test_x_tom)
|
||||||
|
wilcoxon_test(model_tom_8, model_cpa_8, tom8_test_x_tom, test_x_cpa)
|
||||||
|
wilcoxon_test(model_rand_8, model_cpa_8, tom8_test_x_rand, test_x_cpa)
|
||||||
|
wilcoxon_test(model_rand_8, model_tom_8, tom8_test_x_rand, tom8_test_x_tom)
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaled_tom6_test_x_tom = scaler.fit_transform(tom6_test_x_tom)
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaled_tom7_test_x_tom = scaler.fit_transform(tom7_test_x_tom)
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaled_tom8_test_x_tom = scaler.fit_transform(tom8_test_x_tom)
|
||||||
|
scaler = StandardScaler()
|
||||||
|
scaled_test_x_cpa = scaler.fit_transform(test_x_cpa)
|
||||||
|
sim6 = [cosine_similarity(t, c) for t, c in zip(scaled_tom6_test_x_tom, scaled_test_x_cpa)]
|
||||||
|
sim7 = [cosine_similarity(t, c) for t, c in zip(scaled_tom7_test_x_tom, scaled_test_x_cpa)]
|
||||||
|
sim8 = [cosine_similarity(t, c) for t, c in zip(scaled_tom8_test_x_tom, scaled_test_x_cpa)]
|
||||||
|
print(f"[tom6] max sim: {np.max(sim6)}, mean sim: {np.mean(sim6)}")
|
||||||
|
print(f"[tom7] max sim: {np.max(sim7)}, mean sim: {np.mean(sim7)}")
|
||||||
|
print(f"[tom8] max sim: {np.max(sim8)}, mean sim: {np.mean(sim8)}")
|
||||||
|
breakpoint()
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError
|
416
plan_predictor.py
Normal file
416
plan_predictor.py
Normal file
|
@ -0,0 +1,416 @@
|
||||||
|
from glob import glob
|
||||||
|
import os, json, sys
|
||||||
|
import torch, random, torch.nn as nn, numpy as np
|
||||||
|
from torch import optim
|
||||||
|
from random import shuffle
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
|
from src.data.game_parser import GameParser, make_splits, onehot, DEVICE, set_seed
|
||||||
|
from src.models.plan_model import Model
|
||||||
|
from src.models.losses import PlanLoss
|
||||||
|
import argparse
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def print_epoch(data,acc_loss,lst,exp, incremental=False):
|
||||||
|
print(f'{acc_loss:9.4f}',end='; ',flush=True)
|
||||||
|
acc = []
|
||||||
|
prec = []
|
||||||
|
rec = []
|
||||||
|
f1 = []
|
||||||
|
iou = []
|
||||||
|
total = []
|
||||||
|
predicts = []
|
||||||
|
targets = []
|
||||||
|
# for x,game in zip(data,lst):
|
||||||
|
for x in data:
|
||||||
|
|
||||||
|
game = lst[x[2]]
|
||||||
|
game_mats = game.plan['materials']
|
||||||
|
pov_plan = game.plan[f'player{game.pov}']
|
||||||
|
pov_plan_mat = game.__dict__[f'player{game.pov}_plan_mat']
|
||||||
|
possible_mats = [game.materials_dict[x]-1 for x,_ in zip(game_mats[1:],pov_plan[1:])]
|
||||||
|
possible_cand = [game.materials_dict[x]-1 for x,y in zip(game_mats[1:],pov_plan[1:]) if y['make'] and y['make'][0][0]==-1]
|
||||||
|
possible_extra = [game.materials_dict[x]-1 for x,y in zip(game_mats[1:],pov_plan[1:]) if y['make'] and y['make'][0][0]>-1]
|
||||||
|
a, b = x[:2]
|
||||||
|
if exp == 3:
|
||||||
|
a = a.reshape(21,21)
|
||||||
|
for idx,aa in enumerate(a):
|
||||||
|
if idx in possible_extra:
|
||||||
|
cand_idxs = set([i for i,x in enumerate(pov_plan_mat[idx]) if x])
|
||||||
|
th, _ = zip(*sorted([(i, x) for i, x in enumerate(aa) if i in possible_mats], key=lambda x:x[1])[-2:])
|
||||||
|
if len(cand_idxs.intersection(set(th))):
|
||||||
|
for jdx, _ in enumerate(aa):
|
||||||
|
a[idx,jdx] = pov_plan_mat[idx,jdx]
|
||||||
|
else:
|
||||||
|
for jdx, _ in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
else:
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
elif exp == 2:
|
||||||
|
a = a.reshape(21,21)
|
||||||
|
for idx,aa in enumerate(a):
|
||||||
|
if idx in possible_cand:
|
||||||
|
th = [x for i, x in enumerate(aa) if i in possible_mats]
|
||||||
|
th = sorted(th)
|
||||||
|
th = th[-2]
|
||||||
|
th = 1.1 if th < (1/21) else th
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
if idx in possible_mats:
|
||||||
|
a[idx,jdx] = 0 if aaa < th else 1
|
||||||
|
else:
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
else:
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
|
||||||
|
else:
|
||||||
|
a = a.reshape(21,21)
|
||||||
|
for idx,aa in enumerate(a):
|
||||||
|
th = sorted(aa)[-2]
|
||||||
|
th = 1.1 if th < (2.1/21) else th
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0 if aaa < th else 1
|
||||||
|
a = a.reshape(-1)
|
||||||
|
predicts.append(np.argmax(a))
|
||||||
|
targets.append(np.argmax(a) if np.argmax(a) in [x for x in b if x] else np.argmax(b))
|
||||||
|
acc.append(accuracy_score(a,b))
|
||||||
|
sa = set([i for i,x in enumerate(a) if x])
|
||||||
|
sb = set([i for i,x in enumerate(b) if x])
|
||||||
|
i = len(sa.intersection(sb))
|
||||||
|
u = len(sa.union(sb))
|
||||||
|
if u > 0:
|
||||||
|
a,b = zip(*[(x,y) for x,y in zip(a,b) if x+y > 0])
|
||||||
|
f1.append(f1_score(b,a,zero_division=1))
|
||||||
|
prec.append(precision_score(b,a,zero_division=1))
|
||||||
|
rec.append(recall_score(b,a,zero_division=1))
|
||||||
|
iou.append(i/u if u > 0 else 1)
|
||||||
|
total.append(sum(a))
|
||||||
|
print(
|
||||||
|
# f'({accuracy_score(targets,predicts):5.3f},'
|
||||||
|
# f'{np.mean(acc):5.3f},'
|
||||||
|
# f'{np.mean(prec):5.3f},'
|
||||||
|
# f'{np.mean(rec):5.3f},'
|
||||||
|
f'{np.mean(f1):5.3f},'
|
||||||
|
f'{np.mean(iou):5.3f},'
|
||||||
|
f'{np.std(iou):5.3f},',
|
||||||
|
# f'{np.mean(total):5.3f})',
|
||||||
|
end=' ',flush=True)
|
||||||
|
print('', end='; ',flush=True)
|
||||||
|
return accuracy_score(targets,predicts), np.mean(acc), np.mean(f1), np.mean(iou)
|
||||||
|
|
||||||
|
def do_split(model,lst,exp,criterion,optimizer=None,global_plan=False, player_plan=False, incremental=False, device=DEVICE):
|
||||||
|
data = []
|
||||||
|
acc_loss = 0
|
||||||
|
p = []
|
||||||
|
g = []
|
||||||
|
masks = []
|
||||||
|
for batch, game in enumerate(lst):
|
||||||
|
|
||||||
|
if model.training and (not optimizer is None): optimizer.zero_grad()
|
||||||
|
|
||||||
|
if exp==0:
|
||||||
|
ground_truth = torch.tensor(game.global_plan_mat.reshape(-1)).float()
|
||||||
|
elif exp==1:
|
||||||
|
ground_truth = torch.tensor(game.partner_plan.reshape(-1)).float()
|
||||||
|
elif exp==2:
|
||||||
|
ground_truth = torch.tensor(game.global_diff_plan_mat.reshape(-1)).float()
|
||||||
|
loss_mask = torch.tensor(game.global_plan_mat.reshape(-1)).float()
|
||||||
|
else:
|
||||||
|
ground_truth = torch.tensor(game.partner_diff_plan_mat.reshape(-1)).float()
|
||||||
|
loss_mask = torch.tensor(game.plan_repr.reshape(-1)).float()
|
||||||
|
|
||||||
|
prediction, _ = model(game, global_plan=global_plan, player_plan=player_plan, incremental=incremental)
|
||||||
|
|
||||||
|
if incremental:
|
||||||
|
ground_truth = ground_truth.to(device)
|
||||||
|
g += [ground_truth for _ in prediction]
|
||||||
|
masks += [loss_mask for _ in prediction]
|
||||||
|
|
||||||
|
p += [x for x in prediction]
|
||||||
|
|
||||||
|
data += list(zip(prediction.cpu().data.numpy(), [ground_truth.cpu().data.numpy()]*len(prediction),[batch]*len(prediction)))
|
||||||
|
else:
|
||||||
|
ground_truth = ground_truth.to(device)
|
||||||
|
g.append(ground_truth)
|
||||||
|
masks.append(loss_mask)
|
||||||
|
|
||||||
|
p.append(prediction)
|
||||||
|
|
||||||
|
data.append((prediction.cpu().data.numpy(), ground_truth.cpu().data.numpy(),batch))
|
||||||
|
|
||||||
|
if (batch+1) % 2 == 0:
|
||||||
|
loss = criterion(torch.stack(p),torch.stack(g), torch.stack(masks))
|
||||||
|
|
||||||
|
loss += 1e-5 * sum(p.pow(2.0).sum() for p in model.parameters())
|
||||||
|
if model.training and (not optimizer is None):
|
||||||
|
loss.backward()
|
||||||
|
# nn.utils.clip_grad_norm_(model.parameters(), 1)
|
||||||
|
# nn.utils.clip_grad_norm_(model.parameters(), 10)
|
||||||
|
optimizer.step()
|
||||||
|
acc_loss += loss.item()
|
||||||
|
p = []
|
||||||
|
g = []
|
||||||
|
masks = []
|
||||||
|
|
||||||
|
acc_loss /= len(lst)
|
||||||
|
|
||||||
|
acc0, acc, f1, iou = print_epoch(data,acc_loss,lst,exp)
|
||||||
|
|
||||||
|
return acc0, acc_loss, data, acc, f1, iou
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(m.weight)
|
||||||
|
m.bias.data.fill_(0.01)
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args, flush=True)
|
||||||
|
print(f'PID: {os.getpid():6d}', flush=True)
|
||||||
|
|
||||||
|
if isinstance(args.device, int) and args.device >= 0:
|
||||||
|
DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Using {DEVICE}')
|
||||||
|
else:
|
||||||
|
print('Device must be a zero or positive integer, but got',args.device)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# if args.seed=='Random':
|
||||||
|
# pass
|
||||||
|
# elif args.seed=='Fixed':
|
||||||
|
# random.seed(0)
|
||||||
|
# torch.manual_seed(1)
|
||||||
|
# np.random.seed(0)
|
||||||
|
# else:
|
||||||
|
# print('Seed must be in [Random, Fixed], but got',args.seed)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
if isinstance(args.seed, int) and args.seed >= 0:
|
||||||
|
seed = set_seed(args.seed)
|
||||||
|
else:
|
||||||
|
print('Seed must be a zero or positive integer, but got',args.seed)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_dev.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_old.json')
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
|
||||||
|
if args.use_dialogue=='Yes':
|
||||||
|
d_flag = True
|
||||||
|
elif args.use_dialogue=='No':
|
||||||
|
d_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.use_dialogue_moves=='Yes':
|
||||||
|
d_move_flag = True
|
||||||
|
elif args.use_dialogue_moves=='No':
|
||||||
|
d_move_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.experiment in list(range(9)):
|
||||||
|
print('Experiment must be in',list(range(9)),', but got',args.experiment)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.intermediate in list(range(32)):
|
||||||
|
print('Intermediate must be in',list(range(32)),', but got',args.intermediate)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
if args.seq_model=='GRU':
|
||||||
|
seq_model = 0
|
||||||
|
elif args.seq_model=='LSTM':
|
||||||
|
seq_model = 1
|
||||||
|
elif args.seq_model=='Transformer':
|
||||||
|
seq_model = 2
|
||||||
|
else:
|
||||||
|
print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.plans=='Yes':
|
||||||
|
global_plan = (args.pov=='Third') or ((args.pov=='None') and (args.experiment in list(range(3))))
|
||||||
|
player_plan = (args.pov=='First') or ((args.pov=='None') and (args.experiment in list(range(3,9))))
|
||||||
|
elif args.plans=='No' or args.plans is None:
|
||||||
|
global_plan = False
|
||||||
|
player_plan = False
|
||||||
|
else:
|
||||||
|
print('Use Plan must be in [Yes, No], but got',args.plan)
|
||||||
|
exit()
|
||||||
|
print('global_plan', global_plan, 'player_plan', player_plan)
|
||||||
|
|
||||||
|
if args.pov=='None':
|
||||||
|
val = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
val += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
val = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
val = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
val += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
model = Model(seq_model,DEVICE).to(DEVICE)
|
||||||
|
model.apply(init_weights)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
learning_rate = 1e-5
|
||||||
|
weight_decay=1e-4
|
||||||
|
|
||||||
|
# optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
# optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)
|
||||||
|
# optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)
|
||||||
|
# optimizer = optim.Adadelta(model.parameters())
|
||||||
|
# optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
|
||||||
|
# criterion = nn.CrossEntropyLoss()
|
||||||
|
# criterion = nn.BCELoss()
|
||||||
|
criterion = PlanLoss()
|
||||||
|
# criterion = torch.hub.load(
|
||||||
|
# 'adeelh/pytorch-multi-class-focal-loss',
|
||||||
|
# model='focal_loss',
|
||||||
|
# alpha=[.25, .75],
|
||||||
|
# gamma=10,
|
||||||
|
# reduction='mean',
|
||||||
|
# device=device,
|
||||||
|
# dtype=torch.float32,
|
||||||
|
# force_reload=False
|
||||||
|
# )
|
||||||
|
# criterion = nn.BCEWithLogitsLoss(pos_weight=10*torch.ones(21*21).to(device))
|
||||||
|
# criterion = nn.MSELoss()
|
||||||
|
|
||||||
|
print(str(criterion), str(optimizer))
|
||||||
|
|
||||||
|
num_epochs = 200#1#
|
||||||
|
min_acc_loss = 1e6
|
||||||
|
max_f1 = 0
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
wait_epoch = 15#150#1000#
|
||||||
|
max_fails = 5
|
||||||
|
|
||||||
|
if args.model_path is not None:
|
||||||
|
print(f'Loading {args.model_path}')
|
||||||
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
|
model.eval()
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE)
|
||||||
|
acc, acc_loss0, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE)
|
||||||
|
|
||||||
|
if np.mean([acc_loss,acc_loss0]) < min_acc_loss:
|
||||||
|
min_acc_loss = np.mean([acc_loss,acc_loss0])
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
|
||||||
|
# data = list(zip(*data))
|
||||||
|
# for x in data:
|
||||||
|
# a, b = list(zip(*x))
|
||||||
|
# f1 = f1_score(a,b,average='weighted')
|
||||||
|
# f1 = f1_score(a,b,average='weighted')
|
||||||
|
# if (max_f1 < f1):
|
||||||
|
# max_f1 = f1
|
||||||
|
# epochs_since_improvement = 0
|
||||||
|
# print('^')
|
||||||
|
# torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
# model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
print('Training model from scratch', flush=True)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
print(f'{os.getpid():6d} {epoch+1:4d},',end=' ',flush=True)
|
||||||
|
shuffle(train)
|
||||||
|
model.train()
|
||||||
|
do_split(model,train,args.experiment,criterion,optimizer=optimizer,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE)
|
||||||
|
do_split(model,train,args.experiment,criterion,optimizer=optimizer,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE)
|
||||||
|
model.eval()
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE)
|
||||||
|
acc, acc_loss0, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE)
|
||||||
|
|
||||||
|
if np.mean([acc_loss,acc_loss0]) < min_acc_loss:
|
||||||
|
min_acc_loss = np.mean([acc_loss,acc_loss0])
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
epochs_since_improvement += 1
|
||||||
|
print()
|
||||||
|
|
||||||
|
# test_val = iou
|
||||||
|
# if (max_f1 < test_val):
|
||||||
|
# max_f1 = test_val
|
||||||
|
# epochs_since_improvement = 0
|
||||||
|
# print('^')
|
||||||
|
# if not args.save_path is None:
|
||||||
|
# torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
# model = model.to(DEVICE)
|
||||||
|
# else:
|
||||||
|
# epochs_since_improvement += 1
|
||||||
|
# print()
|
||||||
|
|
||||||
|
if epoch > wait_epoch and epochs_since_improvement > max_fails:
|
||||||
|
break
|
||||||
|
print()
|
||||||
|
print('Test')
|
||||||
|
model.load_state_dict(torch.load(args.save_path))
|
||||||
|
|
||||||
|
val = None
|
||||||
|
train = None
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
test += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
model.eval()
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,test,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE)
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,test,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
parser.add_argument('--pov', type=str,
|
||||||
|
help='point of view [None, First, Third]')
|
||||||
|
parser.add_argument('--use_dialogue', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--use_dialogue_moves', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--plans', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--seq_model', type=str,
|
||||||
|
help='point of view [GRU, LSTM, Transformer]')
|
||||||
|
parser.add_argument('--experiment', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--intermediate', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--save_path', type=str,
|
||||||
|
help='path where to save model')
|
||||||
|
parser.add_argument('--seed', type=int,
|
||||||
|
help='Selet random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
|
||||||
|
'set to n\'th random number with original seed set to 0')
|
||||||
|
parser.add_argument('--device', type=int, default=0,
|
||||||
|
help='select cuda device number')
|
||||||
|
parser.add_argument('--model_path', type=str, default=None,
|
||||||
|
help='path to the pretrained model to be loaded')
|
||||||
|
|
||||||
|
main(parser.parse_args())
|
325
plan_predictor_graphs.py
Normal file
325
plan_predictor_graphs.py
Normal file
|
@ -0,0 +1,325 @@
|
||||||
|
import os
|
||||||
|
import torch, torch.nn as nn, numpy as np
|
||||||
|
from torch import optim
|
||||||
|
from random import shuffle
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score
|
||||||
|
from src.data.game_parser_graphs_new import GameParser, make_splits, set_seed
|
||||||
|
from src.models.plan_model_graphs import Model
|
||||||
|
import argparse
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
def print_epoch(data, acc_loss):
|
||||||
|
print(f'{acc_loss:9.4f}',end='; ',flush=True)
|
||||||
|
acc = []
|
||||||
|
f1 = []
|
||||||
|
for x in data:
|
||||||
|
a, b, _, _, _, _, _, _ = x
|
||||||
|
acc.append(accuracy_score(b, a))
|
||||||
|
f1.append(f1_score(b, a, zero_division=1))
|
||||||
|
print(f'{np.mean(f1):5.3f},', end=' ', flush=True)
|
||||||
|
print('', end='; ', flush=True)
|
||||||
|
return np.mean(acc), np.mean(f1), f1
|
||||||
|
|
||||||
|
def do_split(model, lst, exp, criterion, device, optimizer=None, global_plan=False, player_plan=False, incremental=False):
|
||||||
|
data = []
|
||||||
|
acc_loss = 0
|
||||||
|
for batch, game in enumerate(lst):
|
||||||
|
if (exp != 2) and (exp != 3):
|
||||||
|
raise ValueError('This script is only for exp == 2 or exp == 3.')
|
||||||
|
prediction, ground_truth, sel = model(game, experiment=exp, global_plan=global_plan, player_plan=player_plan, incremental=incremental)
|
||||||
|
if exp == 2:
|
||||||
|
if sel[0]:
|
||||||
|
prediction = prediction[game.player1_plan.edge_index.shape[1]:]
|
||||||
|
ground_truth = ground_truth[game.player1_plan.edge_index.shape[1]:]
|
||||||
|
if sel[1]:
|
||||||
|
prediction = prediction[game.player2_plan.edge_index.shape[1]:]
|
||||||
|
ground_truth = ground_truth[game.player2_plan.edge_index.shape[1]:]
|
||||||
|
if prediction.numel() == 0 and ground_truth.numel() == 0: continue
|
||||||
|
if incremental:
|
||||||
|
ground_truth = ground_truth.to(device).repeat(prediction.shape[0], 1)
|
||||||
|
data += list(zip(torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
|
||||||
|
ground_truth.cpu().data.numpy(),
|
||||||
|
[game.player1_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[game.player2_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[game.global_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[sel]*len(prediction),
|
||||||
|
[game.game_path]*len(prediction),
|
||||||
|
[batch]*len(prediction)))
|
||||||
|
else:
|
||||||
|
ground_truth = ground_truth.to(device)
|
||||||
|
data.append((
|
||||||
|
torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
|
||||||
|
ground_truth.cpu().data.numpy(),
|
||||||
|
game.player1_plan.edge_index.shape[1],
|
||||||
|
game.player2_plan.edge_index.shape[1],
|
||||||
|
game.global_plan.edge_index.shape[1],
|
||||||
|
sel,
|
||||||
|
game.game_path,
|
||||||
|
batch,
|
||||||
|
))
|
||||||
|
loss = criterion(prediction, ground_truth)
|
||||||
|
# loss += 1e-5 * sum(p.pow(2.0).sum() for p in model.parameters())
|
||||||
|
acc_loss += loss.item()
|
||||||
|
if model.training and (not optimizer is None):
|
||||||
|
loss.backward()
|
||||||
|
if (batch+1) % 2 == 0: # gradient accumulation
|
||||||
|
# nn.utils.clip_grad_norm_(model.parameters(), 1)
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
acc_loss /= len(lst)
|
||||||
|
acc, f1, f1_list = print_epoch(data, acc_loss)
|
||||||
|
if not incremental:
|
||||||
|
data = [data[i] + (f1_list[i],) for i in range(len(data))]
|
||||||
|
return acc_loss, data, acc, f1
|
||||||
|
|
||||||
|
def init_weights(m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(m.weight)
|
||||||
|
m.bias.data.fill_(0.00)
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args, flush=True)
|
||||||
|
print(f'PID: {os.getpid():6d}', flush=True)
|
||||||
|
|
||||||
|
if isinstance(args.device, int) and args.device >= 0:
|
||||||
|
DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Using {DEVICE}')
|
||||||
|
else:
|
||||||
|
print('Device must be a zero or positive integer, but got',args.device)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if isinstance(args.seed, int) and args.seed >= 0:
|
||||||
|
seed = set_seed(args.seed)
|
||||||
|
else:
|
||||||
|
print('Seed must be a zero or positive integer, but got',args.seed)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_dev.json')
|
||||||
|
|
||||||
|
if args.use_dialogue=='Yes':
|
||||||
|
d_flag = True
|
||||||
|
elif args.use_dialogue=='No':
|
||||||
|
d_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.use_dialogue_moves=='Yes':
|
||||||
|
d_move_flag = True
|
||||||
|
elif args.use_dialogue_moves=='No':
|
||||||
|
d_move_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.experiment in list(range(9)):
|
||||||
|
print('Experiment must be in',list(range(9)),', but got',args.experiment)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.intermediate in list(range(32)):
|
||||||
|
print('Intermediate must be in',list(range(32)),', but got',args.intermediate)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.seq_model=='GRU':
|
||||||
|
seq_model = 0
|
||||||
|
elif args.seq_model=='LSTM':
|
||||||
|
seq_model = 1
|
||||||
|
elif args.seq_model=='Transformer':
|
||||||
|
seq_model = 2
|
||||||
|
else:
|
||||||
|
print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.plans=='Yes':
|
||||||
|
global_plan = (args.pov=='Third') or ((args.pov=='None') and (args.experiment in list(range(3))))
|
||||||
|
player_plan = (args.pov=='First') or ((args.pov=='None') and (args.experiment in list(range(3,9))))
|
||||||
|
elif args.plans=='No' or args.plans is None:
|
||||||
|
global_plan = False
|
||||||
|
player_plan = False
|
||||||
|
else:
|
||||||
|
print('Use Plan must be in [Yes, No], but got',args.plan)
|
||||||
|
exit()
|
||||||
|
print('global_plan', global_plan, 'player_plan', player_plan)
|
||||||
|
|
||||||
|
if args.use_int0_instead_of_intermediate:
|
||||||
|
if args.pov=='None':
|
||||||
|
val = [GameParser(f,d_flag,0,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,0,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
val += [GameParser(f,d_flag,4,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,4,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
val = [GameParser(f,d_flag,3,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,3,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
val = [GameParser(f,d_flag,1,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,1,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['training'])]
|
||||||
|
val += [GameParser(f,d_flag,2,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,2,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['training'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
else:
|
||||||
|
if args.pov=='None':
|
||||||
|
val = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
val += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
val = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
val = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
val += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
model = Model(seq_model, DEVICE).to(DEVICE)
|
||||||
|
# model.apply(init_weights)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
learning_rate = args.lr
|
||||||
|
weight_decay = args.weight_decay
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
if args.experiment == 2:
|
||||||
|
pos_weight = torch.tensor([2.5], device=DEVICE)
|
||||||
|
if args.experiment == 3:
|
||||||
|
pos_weight = torch.tensor([10.0], device=DEVICE)
|
||||||
|
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
|
||||||
|
# criterion = nn.BCEWithLogitsLoss()
|
||||||
|
print(str(criterion), str(optimizer))
|
||||||
|
|
||||||
|
num_epochs = 200
|
||||||
|
min_acc_loss = 1e6
|
||||||
|
best_f1 = 0.0
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
wait_epoch = 15
|
||||||
|
max_fails = 5
|
||||||
|
|
||||||
|
if args.model_path is not None:
|
||||||
|
print(f'Loading {args.model_path}')
|
||||||
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
|
model.eval()
|
||||||
|
# acc_loss, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
acc_loss0, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
# if np.mean([acc_loss, acc_loss0]) < min_acc_loss:
|
||||||
|
if f1 > best_f1:
|
||||||
|
# min_acc_loss = np.mean([acc_loss, acc_loss0])
|
||||||
|
best_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
print('Training model from scratch', flush=True)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
print(f'{os.getpid():6d} {epoch+1:4d},',end=' ',flush=True)
|
||||||
|
shuffle(train)
|
||||||
|
model.train()
|
||||||
|
# do_split(model, train, args.experiment, criterion, device=DEVICE, optimizer=optimizer, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
do_split(model, train, args.experiment, criterion, device=DEVICE, optimizer=optimizer, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
model.eval()
|
||||||
|
# acc_loss, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
acc_loss0, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
|
||||||
|
# if np.mean([acc_loss, acc_loss0]) < min_acc_loss:
|
||||||
|
# if acc_loss0 < min_acc_loss:
|
||||||
|
if f1 > best_f1:
|
||||||
|
# min_acc_loss = np.mean([acc_loss, acc_loss0])
|
||||||
|
# min_acc_loss = acc_loss0
|
||||||
|
best_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
epochs_since_improvement += 1
|
||||||
|
print()
|
||||||
|
|
||||||
|
if epoch > wait_epoch and epochs_since_improvement > max_fails:
|
||||||
|
break
|
||||||
|
print()
|
||||||
|
print('Test')
|
||||||
|
model.load_state_dict(torch.load(args.save_path))
|
||||||
|
|
||||||
|
if args.use_int0_instead_of_intermediate:
|
||||||
|
val = None
|
||||||
|
train = None
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['test'])]
|
||||||
|
test += [GameParser(f,d_flag,2,args.intermediate,d_move_flag,load_int0_feats=True) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
else:
|
||||||
|
val = None
|
||||||
|
train = None
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
test += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
# acc_loss, data, acc, f1 = do_split(model, test, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
acc_loss, data, acc, f1 = do_split(model, test, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
with open(f'{args.save_path[:-6]}_data.pkl', 'wb') as f:
|
||||||
|
pickle.dump(data, f)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
parser.add_argument('--pov', type=str,
|
||||||
|
help='point of view [None, First, Third]')
|
||||||
|
parser.add_argument('--use_dialogue', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--use_dialogue_moves', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--plans', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--seq_model', type=str,
|
||||||
|
help='point of view [GRU, LSTM, Transformer]')
|
||||||
|
parser.add_argument('--experiment', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--intermediate', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--save_path', type=str,
|
||||||
|
help='path where to save model')
|
||||||
|
parser.add_argument('--seed', type=int,
|
||||||
|
help='Selet random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
|
||||||
|
'set to n\'th random number with original seed set to 0')
|
||||||
|
parser.add_argument('--device', type=int, default=0,
|
||||||
|
help='select cuda device number')
|
||||||
|
parser.add_argument('--model_path', type=str, default=None,
|
||||||
|
help='path to the pretrained model to be loaded')
|
||||||
|
parser.add_argument('--weight_decay', type=float, default=0.0)
|
||||||
|
parser.add_argument('--lr', type=float, default=1e-4)
|
||||||
|
parser.add_argument('--use_int0_instead_of_intermediate', action='store_true')
|
||||||
|
|
||||||
|
|
||||||
|
main(parser.parse_args())
|
283
plan_predictor_graphs_oracle.py
Normal file
283
plan_predictor_graphs_oracle.py
Normal file
|
@ -0,0 +1,283 @@
|
||||||
|
import os
|
||||||
|
import torch, torch.nn as nn, numpy as np
|
||||||
|
from torch import optim
|
||||||
|
from random import shuffle
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score
|
||||||
|
from src.data.game_parser_graphs_new import GameParser, make_splits, set_seed
|
||||||
|
from src.models.plan_model_graphs_oracle import Model
|
||||||
|
import argparse
|
||||||
|
from tqdm import tqdm
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
|
def print_epoch(data, acc_loss):
|
||||||
|
print(f'{acc_loss:9.4f}',end='; ',flush=True)
|
||||||
|
acc = []
|
||||||
|
f1 = []
|
||||||
|
for x in data:
|
||||||
|
a, b, _, _, _, _, _, _ = x
|
||||||
|
acc.append(accuracy_score(b, a))
|
||||||
|
f1.append(f1_score(b, a, zero_division=1))
|
||||||
|
print(f'{np.mean(f1):5.3f},', end=' ', flush=True)
|
||||||
|
print('', end='; ', flush=True)
|
||||||
|
return np.mean(acc), np.mean(f1), f1
|
||||||
|
|
||||||
|
def do_split(model, lst, exp, criterion, device, optimizer=None, global_plan=False, player_plan=False, incremental=False, intermediate=0):
|
||||||
|
data = []
|
||||||
|
acc_loss = 0
|
||||||
|
for batch, game in enumerate(lst):
|
||||||
|
if (exp != 2) and (exp != 3):
|
||||||
|
raise ValueError('This script is only for exp == 2 or exp == 3.')
|
||||||
|
prediction, ground_truth, sel = model(game, experiment=exp, global_plan=global_plan, player_plan=player_plan, incremental=incremental, intermediate=intermediate)
|
||||||
|
if exp == 2:
|
||||||
|
if sel[0]:
|
||||||
|
prediction = prediction[game.player1_plan.edge_index.shape[1]:]
|
||||||
|
ground_truth = ground_truth[game.player1_plan.edge_index.shape[1]:]
|
||||||
|
if sel[1]:
|
||||||
|
prediction = prediction[game.player2_plan.edge_index.shape[1]:]
|
||||||
|
ground_truth = ground_truth[game.player2_plan.edge_index.shape[1]:]
|
||||||
|
if prediction.numel() == 0 and ground_truth.numel() == 0: continue
|
||||||
|
if incremental:
|
||||||
|
ground_truth = ground_truth.to(device).repeat(prediction.shape[0], 1)
|
||||||
|
data += list(zip(torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
|
||||||
|
ground_truth.cpu().data.numpy(),
|
||||||
|
[game.player1_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[game.player2_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[game.global_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[sel]*len(prediction),
|
||||||
|
[game.game_path]*len(prediction),
|
||||||
|
[batch]*len(prediction)))
|
||||||
|
else:
|
||||||
|
ground_truth = ground_truth.to(device)
|
||||||
|
data.append((
|
||||||
|
torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
|
||||||
|
ground_truth.cpu().data.numpy(),
|
||||||
|
game.player1_plan.edge_index.shape[1],
|
||||||
|
game.player2_plan.edge_index.shape[1],
|
||||||
|
game.global_plan.edge_index.shape[1],
|
||||||
|
sel,
|
||||||
|
game.game_path,
|
||||||
|
batch,
|
||||||
|
))
|
||||||
|
loss = criterion(prediction, ground_truth)
|
||||||
|
# loss += 1e-5 * sum(p.pow(2.0).sum() for p in model.parameters())
|
||||||
|
acc_loss += loss.item()
|
||||||
|
if model.training and (not optimizer is None):
|
||||||
|
loss.backward()
|
||||||
|
if (batch+1) % 2 == 0: # gradient accumulation
|
||||||
|
# nn.utils.clip_grad_norm_(model.parameters(), 1)
|
||||||
|
optimizer.step()
|
||||||
|
optimizer.zero_grad()
|
||||||
|
acc_loss /= len(lst)
|
||||||
|
acc, f1, f1_list = print_epoch(data, acc_loss)
|
||||||
|
if not incremental:
|
||||||
|
data = [data[i] + (f1_list[i],) for i in range(len(data))]
|
||||||
|
return acc_loss, data, acc, f1
|
||||||
|
|
||||||
|
def init_weights(m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(m.weight)
|
||||||
|
m.bias.data.fill_(0.00)
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args, flush=True)
|
||||||
|
print(f'PID: {os.getpid():6d}', flush=True)
|
||||||
|
|
||||||
|
if isinstance(args.device, int) and args.device >= 0:
|
||||||
|
DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Using {DEVICE}')
|
||||||
|
else:
|
||||||
|
print('Device must be a zero or positive integer, but got',args.device)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if isinstance(args.seed, int) and args.seed >= 0:
|
||||||
|
seed = set_seed(args.seed)
|
||||||
|
else:
|
||||||
|
print('Seed must be a zero or positive integer, but got',args.seed)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_dev.json')
|
||||||
|
|
||||||
|
if args.use_dialogue=='Yes':
|
||||||
|
d_flag = True
|
||||||
|
elif args.use_dialogue=='No':
|
||||||
|
d_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.use_dialogue_moves=='Yes':
|
||||||
|
d_move_flag = True
|
||||||
|
elif args.use_dialogue_moves=='No':
|
||||||
|
d_move_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.experiment in list(range(9)):
|
||||||
|
print('Experiment must be in',list(range(9)),', but got',args.experiment)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.intermediate in list(range(32)):
|
||||||
|
print('Intermediate must be in',list(range(32)),', but got',args.intermediate)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.seq_model=='GRU':
|
||||||
|
seq_model = 0
|
||||||
|
elif args.seq_model=='LSTM':
|
||||||
|
seq_model = 1
|
||||||
|
elif args.seq_model=='Transformer':
|
||||||
|
seq_model = 2
|
||||||
|
else:
|
||||||
|
print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.plans=='Yes':
|
||||||
|
global_plan = (args.pov=='Third') or ((args.pov=='None') and (args.experiment in list(range(3))))
|
||||||
|
player_plan = (args.pov=='First') or ((args.pov=='None') and (args.experiment in list(range(3,9))))
|
||||||
|
elif args.plans=='No' or args.plans is None:
|
||||||
|
global_plan = False
|
||||||
|
player_plan = False
|
||||||
|
else:
|
||||||
|
print('Use Plan must be in [Yes, No], but got',args.plan)
|
||||||
|
exit()
|
||||||
|
print('global_plan', global_plan, 'player_plan', player_plan)
|
||||||
|
|
||||||
|
if args.pov=='None':
|
||||||
|
val = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
val += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
val = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
val = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
val += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
model = Model(seq_model, DEVICE).to(DEVICE)
|
||||||
|
# model.apply(init_weights)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
learning_rate = 1e-4
|
||||||
|
weight_decay = 0.0 #1e-4
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([2.5], device=DEVICE))
|
||||||
|
# criterion = nn.BCEWithLogitsLoss()
|
||||||
|
print(str(criterion), str(optimizer))
|
||||||
|
|
||||||
|
num_epochs = 200
|
||||||
|
min_acc_loss = 1e6
|
||||||
|
best_f1 = 0.0
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
wait_epoch = 15
|
||||||
|
max_fails = 5
|
||||||
|
|
||||||
|
if args.model_path is not None:
|
||||||
|
print(f'Loading {args.model_path}')
|
||||||
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
|
model.eval()
|
||||||
|
# acc_loss, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
acc_loss0, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
# if np.mean([acc_loss, acc_loss0]) < min_acc_loss:
|
||||||
|
if f1 > best_f1:
|
||||||
|
# min_acc_loss = np.mean([acc_loss, acc_loss0])
|
||||||
|
best_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
print('Training model from scratch', flush=True)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
print(f'{os.getpid():6d} {epoch+1:4d},',end=' ',flush=True)
|
||||||
|
shuffle(train)
|
||||||
|
model.train()
|
||||||
|
# do_split(model, train, args.experiment, criterion, device=DEVICE, optimizer=optimizer, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
do_split(model, train, args.experiment, criterion, device=DEVICE, optimizer=optimizer, global_plan=global_plan, player_plan=player_plan, incremental=False, intermediate=args.intermediate)
|
||||||
|
model.eval()
|
||||||
|
# acc_loss, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
acc_loss0, data, acc, f1 = do_split(model, val, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False, intermediate=args.intermediate)
|
||||||
|
|
||||||
|
# if np.mean([acc_loss, acc_loss0]) < min_acc_loss:
|
||||||
|
# if acc_loss0 < min_acc_loss:
|
||||||
|
if f1 > best_f1:
|
||||||
|
# min_acc_loss = np.mean([acc_loss, acc_loss0])
|
||||||
|
# min_acc_loss = acc_loss0
|
||||||
|
best_f1 = f1
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
epochs_since_improvement += 1
|
||||||
|
print()
|
||||||
|
|
||||||
|
if epoch > wait_epoch and epochs_since_improvement > max_fails:
|
||||||
|
break
|
||||||
|
print()
|
||||||
|
print('Test')
|
||||||
|
model.load_state_dict(torch.load(args.save_path))
|
||||||
|
|
||||||
|
val = None
|
||||||
|
train = None
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
test += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
# acc_loss, data, acc, f1 = do_split(model, test, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=True)
|
||||||
|
acc_loss, data, acc, f1 = do_split(model, test, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
with open(f'{args.save_path[:-6]}_data.pkl', 'wb') as f:
|
||||||
|
pickle.dump(data, f)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
parser.add_argument('--pov', type=str,
|
||||||
|
help='point of view [None, First, Third]')
|
||||||
|
parser.add_argument('--use_dialogue', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--use_dialogue_moves', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--plans', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--seq_model', type=str,
|
||||||
|
help='point of view [GRU, LSTM, Transformer]')
|
||||||
|
parser.add_argument('--experiment', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--intermediate', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--save_path', type=str,
|
||||||
|
help='path where to save model')
|
||||||
|
parser.add_argument('--seed', type=int,
|
||||||
|
help='Selet random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
|
||||||
|
'set to n\'th random number with original seed set to 0')
|
||||||
|
parser.add_argument('--device', type=int, default=0,
|
||||||
|
help='select cuda device number')
|
||||||
|
parser.add_argument('--model_path', type=str, default=None,
|
||||||
|
help='path to the pretrained model to be loaded')
|
||||||
|
|
||||||
|
main(parser.parse_args())
|
201
plan_predictor_graphs_test.py
Normal file
201
plan_predictor_graphs_test.py
Normal file
|
@ -0,0 +1,201 @@
|
||||||
|
import os
|
||||||
|
import torch, torch.nn as nn, numpy as np
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score
|
||||||
|
from src.data.game_parser_graphs_new import GameParser, make_splits, set_seed
|
||||||
|
from src.models.plan_model_graphs import Model
|
||||||
|
import argparse
|
||||||
|
import pickle
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
|
def print_epoch(data, acc_loss):
|
||||||
|
print(f'{acc_loss:9.4f}',end='; ',flush=True)
|
||||||
|
acc = []
|
||||||
|
f1 = []
|
||||||
|
for x in data:
|
||||||
|
a, b, _, _, _, _, _, _ = x
|
||||||
|
acc.append(accuracy_score(b, a))
|
||||||
|
f1.append(f1_score(b, a, zero_division=1))
|
||||||
|
print(f'{np.mean(f1):5.3f},', end=' ', flush=True)
|
||||||
|
print('', end='; ', flush=True)
|
||||||
|
return np.mean(acc), np.mean(f1), f1
|
||||||
|
|
||||||
|
def do_split(model, lst, exp, criterion, device, optimizer=None, global_plan=False, player_plan=False, incremental=False):
|
||||||
|
data = []
|
||||||
|
seq2seq_feats = []
|
||||||
|
acc_loss = 0
|
||||||
|
for batch, game in enumerate(lst):
|
||||||
|
if (exp != 2) and (exp != 3):
|
||||||
|
raise ValueError('This script is only for exp == 2 or exp == 3.')
|
||||||
|
prediction, ground_truth, sel, feats = model(game, experiment=exp, global_plan=global_plan, player_plan=player_plan, incremental=incremental, return_feats=True)
|
||||||
|
seq2seq_feats.append([feats, game.game_path])
|
||||||
|
if exp == 2:
|
||||||
|
if sel[0]:
|
||||||
|
prediction = prediction[game.player1_plan.edge_index.shape[1]:]
|
||||||
|
ground_truth = ground_truth[game.player1_plan.edge_index.shape[1]:]
|
||||||
|
if sel[1]:
|
||||||
|
prediction = prediction[game.player2_plan.edge_index.shape[1]:]
|
||||||
|
ground_truth = ground_truth[game.player2_plan.edge_index.shape[1]:]
|
||||||
|
if prediction.numel() == 0 and ground_truth.numel() == 0: continue
|
||||||
|
if incremental:
|
||||||
|
ground_truth = ground_truth.to(device).repeat(prediction.shape[0], 1)
|
||||||
|
data += list(zip(torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
|
||||||
|
ground_truth.cpu().data.numpy(),
|
||||||
|
[game.player1_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[game.player2_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[game.global_plan.edge_index.shape[1]]*len(prediction),
|
||||||
|
[sel]*len(prediction),
|
||||||
|
[game.game_path]*len(prediction),
|
||||||
|
[batch]*len(prediction)))
|
||||||
|
else:
|
||||||
|
ground_truth = ground_truth.to(device)
|
||||||
|
data.append((
|
||||||
|
torch.round(torch.sigmoid(prediction)).float().cpu().data.numpy(),
|
||||||
|
ground_truth.cpu().data.numpy(),
|
||||||
|
game.player1_plan.edge_index.shape[1],
|
||||||
|
game.player2_plan.edge_index.shape[1],
|
||||||
|
game.global_plan.edge_index.shape[1],
|
||||||
|
sel,
|
||||||
|
game.game_path,
|
||||||
|
batch,
|
||||||
|
))
|
||||||
|
loss = criterion(prediction, ground_truth)
|
||||||
|
acc_loss += loss.item()
|
||||||
|
acc_loss /= len(lst)
|
||||||
|
acc, f1, f1_list = print_epoch(data, acc_loss)
|
||||||
|
if not incremental:
|
||||||
|
data = [data[i] + (f1_list[i],) for i in range(len(data))]
|
||||||
|
return acc_loss, data, acc, f1, seq2seq_feats
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args, flush=True)
|
||||||
|
print(f'PID: {os.getpid():6d}', flush=True)
|
||||||
|
|
||||||
|
if isinstance(args.device, int) and args.device >= 0:
|
||||||
|
DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Using {DEVICE}')
|
||||||
|
else:
|
||||||
|
print('Device must be a zero or positive integer, but got',args.device)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if isinstance(args.seed, int) and args.seed >= 0:
|
||||||
|
seed = set_seed(args.seed)
|
||||||
|
else:
|
||||||
|
print('Seed must be a zero or positive integer, but got',args.seed)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
|
||||||
|
if args.use_dialogue=='Yes':
|
||||||
|
d_flag = True
|
||||||
|
elif args.use_dialogue=='No':
|
||||||
|
d_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.use_dialogue_moves=='Yes':
|
||||||
|
d_move_flag = True
|
||||||
|
elif args.use_dialogue_moves=='No':
|
||||||
|
d_move_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.experiment in list(range(9)):
|
||||||
|
print('Experiment must be in',list(range(9)),', but got',args.experiment)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.intermediate in list(range(32)):
|
||||||
|
print('Intermediate must be in',list(range(32)),', but got',args.intermediate)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.seq_model=='GRU':
|
||||||
|
seq_model = 0
|
||||||
|
elif args.seq_model=='LSTM':
|
||||||
|
seq_model = 1
|
||||||
|
elif args.seq_model=='Transformer':
|
||||||
|
seq_model = 2
|
||||||
|
else:
|
||||||
|
print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.plans=='Yes':
|
||||||
|
global_plan = (args.pov=='Third') or ((args.pov=='None') and (args.experiment in list(range(3))))
|
||||||
|
player_plan = (args.pov=='First') or ((args.pov=='None') and (args.experiment in list(range(3,9))))
|
||||||
|
elif args.plans=='No' or args.plans is None:
|
||||||
|
global_plan = False
|
||||||
|
player_plan = False
|
||||||
|
else:
|
||||||
|
print('Use Plan must be in [Yes, No], but got',args.plan)
|
||||||
|
exit()
|
||||||
|
print('global_plan', global_plan, 'player_plan', player_plan)
|
||||||
|
|
||||||
|
criterion = nn.BCEWithLogitsLoss()
|
||||||
|
|
||||||
|
model = Model(seq_model, DEVICE).to(DEVICE)
|
||||||
|
|
||||||
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
test += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
|
||||||
|
######### TEST
|
||||||
|
acc_loss, data, acc, f1, seq2seq_feats = do_split(model, test, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
with open(f'{args.model_path[:-6]}_feats_test.pkl', 'wb') as f:
|
||||||
|
pickle.dump(seq2seq_feats, f)
|
||||||
|
|
||||||
|
if args.pov=='None':
|
||||||
|
train = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
train += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
train = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
train = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
train += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
######### TRAIN
|
||||||
|
acc_loss, data, acc, f1, seq2seq_feats = do_split(model, train, args.experiment, criterion, device=DEVICE, global_plan=global_plan, player_plan=player_plan, incremental=False)
|
||||||
|
with open(f'{args.model_path[:-6]}_feats_train.pkl', 'wb') as f:
|
||||||
|
pickle.dump(seq2seq_feats, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
parser.add_argument('--pov', type=str, default='First',
|
||||||
|
help='point of view [None, First, Third]')
|
||||||
|
parser.add_argument('--use_dialogue', type=str, default='Yes',
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--use_dialogue_moves', type=str, default='Yes',
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--plans', type=str, default='Yes',
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--seq_model', type=str, default='Transformer',
|
||||||
|
help='point of view [GRU, LSTM, Transformer]')
|
||||||
|
parser.add_argument('--experiment', type=int, default=2,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--intermediate', type=int,
|
||||||
|
help='')
|
||||||
|
parser.add_argument('--seed', type=int,
|
||||||
|
help='Selet random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
|
||||||
|
'set to n\'th random number with original seed set to 0')
|
||||||
|
parser.add_argument('--device', type=int, default=0,
|
||||||
|
help='select cuda device number')
|
||||||
|
parser.add_argument('--model_path', type=str, default=None,
|
||||||
|
help='path to the pretrained model to be loaded')
|
||||||
|
|
||||||
|
main(parser.parse_args())
|
351
plan_predictor_oracle.py
Normal file
351
plan_predictor_oracle.py
Normal file
|
@ -0,0 +1,351 @@
|
||||||
|
import os
|
||||||
|
import torch, torch.nn as nn, numpy as np
|
||||||
|
from torch import optim
|
||||||
|
from random import shuffle
|
||||||
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
|
||||||
|
from src.data.game_parser import GameParser, make_splits, onehot, DEVICE, set_seed
|
||||||
|
from src.models.plan_model_oracle import Model
|
||||||
|
from src.models.losses import PlanLoss
|
||||||
|
import argparse
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
def print_epoch(data,acc_loss,lst,exp, incremental=False):
|
||||||
|
print(f'{acc_loss:9.4f}',end='; ',flush=True)
|
||||||
|
acc = []
|
||||||
|
prec = []
|
||||||
|
rec = []
|
||||||
|
f1 = []
|
||||||
|
iou = []
|
||||||
|
total = []
|
||||||
|
predicts = []
|
||||||
|
targets = []
|
||||||
|
for x in data:
|
||||||
|
game = lst[x[2]]
|
||||||
|
game_mats = game.plan['materials']
|
||||||
|
pov_plan = game.plan[f'player{game.pov}']
|
||||||
|
pov_plan_mat = game.__dict__[f'player{game.pov}_plan_mat']
|
||||||
|
possible_mats = [game.materials_dict[x]-1 for x,_ in zip(game_mats[1:],pov_plan[1:])]
|
||||||
|
possible_cand = [game.materials_dict[x]-1 for x,y in zip(game_mats[1:],pov_plan[1:]) if y['make'] and y['make'][0][0]==-1]
|
||||||
|
possible_extra = [game.materials_dict[x]-1 for x,y in zip(game_mats[1:],pov_plan[1:]) if y['make'] and y['make'][0][0]>-1]
|
||||||
|
a, b = x[:2]
|
||||||
|
if exp == 3:
|
||||||
|
a = a.reshape(21,21)
|
||||||
|
for idx,aa in enumerate(a):
|
||||||
|
if idx in possible_extra:
|
||||||
|
cand_idxs = set([i for i,x in enumerate(pov_plan_mat[idx]) if x])
|
||||||
|
th, _ = zip(*sorted([(i, x) for i, x in enumerate(aa) if i in possible_mats], key=lambda x:x[1])[-2:])
|
||||||
|
if len(cand_idxs.intersection(set(th))):
|
||||||
|
for jdx, _ in enumerate(aa):
|
||||||
|
a[idx,jdx] = pov_plan_mat[idx,jdx]
|
||||||
|
else:
|
||||||
|
for jdx, _ in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
else:
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
elif exp == 2:
|
||||||
|
a = a.reshape(21,21)
|
||||||
|
for idx,aa in enumerate(a):
|
||||||
|
if idx in possible_cand:
|
||||||
|
th = [x for i, x in enumerate(aa) if i in possible_mats]
|
||||||
|
th = sorted(th)
|
||||||
|
th = th[-2]
|
||||||
|
th = 1.1 if th < (1/21) else th
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
if idx in possible_mats:
|
||||||
|
a[idx,jdx] = 0 if aaa < th else 1
|
||||||
|
else:
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
else:
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0
|
||||||
|
|
||||||
|
else:
|
||||||
|
a = a.reshape(21,21)
|
||||||
|
for idx,aa in enumerate(a):
|
||||||
|
th = sorted(aa)[-2]
|
||||||
|
th = 1.1 if th < (2.1/21) else th
|
||||||
|
for jdx, aaa in enumerate(aa):
|
||||||
|
a[idx,jdx] = 0 if aaa < th else 1
|
||||||
|
a = a.reshape(-1)
|
||||||
|
predicts.append(np.argmax(a))
|
||||||
|
targets.append(np.argmax(a) if np.argmax(a) in [x for x in b if x] else np.argmax(b))
|
||||||
|
acc.append(accuracy_score(a,b))
|
||||||
|
sa = set([i for i,x in enumerate(a) if x])
|
||||||
|
sb = set([i for i,x in enumerate(b) if x])
|
||||||
|
i = len(sa.intersection(sb))
|
||||||
|
u = len(sa.union(sb))
|
||||||
|
if u > 0:
|
||||||
|
a,b = zip(*[(x,y) for x,y in zip(a,b) if x+y > 0])
|
||||||
|
f1.append(f1_score(b,a,zero_division=1))
|
||||||
|
prec.append(precision_score(b,a,zero_division=1))
|
||||||
|
rec.append(recall_score(b,a,zero_division=1))
|
||||||
|
iou.append(i/u if u > 0 else 1)
|
||||||
|
total.append(sum(a))
|
||||||
|
print(
|
||||||
|
f'{np.mean(f1):5.3f},'
|
||||||
|
f'{np.mean(iou):5.3f},'
|
||||||
|
f'{np.std(iou):5.3f},',
|
||||||
|
end=' ',flush=True)
|
||||||
|
print('', end='; ',flush=True)
|
||||||
|
return accuracy_score(targets,predicts), np.mean(acc), np.mean(f1), np.mean(iou)
|
||||||
|
|
||||||
|
def do_split(model,lst,exp,criterion,optimizer=None,global_plan=False, player_plan=False, incremental=False, device=DEVICE, intermediate=0):
|
||||||
|
data = []
|
||||||
|
acc_loss = 0
|
||||||
|
p = []
|
||||||
|
g = []
|
||||||
|
masks = []
|
||||||
|
for batch, game in enumerate(lst):
|
||||||
|
|
||||||
|
if model.training and (not optimizer is None): optimizer.zero_grad()
|
||||||
|
|
||||||
|
if exp==0:
|
||||||
|
ground_truth = torch.tensor(game.global_plan_mat.reshape(-1)).float()
|
||||||
|
elif exp==1:
|
||||||
|
ground_truth = torch.tensor(game.partner_plan.reshape(-1)).float()
|
||||||
|
elif exp==2:
|
||||||
|
ground_truth = torch.tensor(game.global_diff_plan_mat.reshape(-1)).float()
|
||||||
|
loss_mask = torch.tensor(game.global_plan_mat.reshape(-1)).float()
|
||||||
|
else:
|
||||||
|
ground_truth = torch.tensor(game.partner_diff_plan_mat.reshape(-1)).float()
|
||||||
|
loss_mask = torch.tensor(game.plan_repr.reshape(-1)).float()
|
||||||
|
|
||||||
|
prediction, _ = model(game, global_plan=global_plan, player_plan=player_plan, incremental=incremental, intermediate=intermediate)
|
||||||
|
|
||||||
|
if incremental:
|
||||||
|
ground_truth = ground_truth.to(device)
|
||||||
|
g += [ground_truth for _ in prediction]
|
||||||
|
masks += [loss_mask for _ in prediction]
|
||||||
|
|
||||||
|
p += [x for x in prediction]
|
||||||
|
|
||||||
|
data += list(zip(prediction.cpu().data.numpy(), [ground_truth.cpu().data.numpy()]*len(prediction),[batch]*len(prediction)))
|
||||||
|
else:
|
||||||
|
ground_truth = ground_truth.to(device)
|
||||||
|
g.append(ground_truth)
|
||||||
|
masks.append(loss_mask)
|
||||||
|
|
||||||
|
p.append(prediction)
|
||||||
|
|
||||||
|
data.append((prediction.cpu().data.numpy(), ground_truth.cpu().data.numpy(),batch))
|
||||||
|
|
||||||
|
if (batch+1) % 2 == 0:
|
||||||
|
loss = criterion(torch.stack(p),torch.stack(g), torch.stack(masks))
|
||||||
|
|
||||||
|
loss += 1e-5 * sum(p.pow(2.0).sum() for p in model.parameters())
|
||||||
|
if model.training and (not optimizer is None):
|
||||||
|
loss.backward()
|
||||||
|
optimizer.step()
|
||||||
|
acc_loss += loss.item()
|
||||||
|
p = []
|
||||||
|
g = []
|
||||||
|
masks = []
|
||||||
|
|
||||||
|
acc_loss /= len(lst)
|
||||||
|
|
||||||
|
acc0, acc, f1, iou = print_epoch(data,acc_loss,lst,exp)
|
||||||
|
|
||||||
|
return acc0, acc_loss, data, acc, f1, iou
|
||||||
|
|
||||||
|
|
||||||
|
def init_weights(m):
|
||||||
|
if isinstance(m, nn.Linear):
|
||||||
|
torch.nn.init.xavier_uniform_(m.weight)
|
||||||
|
m.bias.data.fill_(0.01)
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
print(args, flush=True)
|
||||||
|
print(f'PID: {os.getpid():6d}', flush=True)
|
||||||
|
|
||||||
|
if isinstance(args.device, int) and args.device >= 0:
|
||||||
|
DEVICE = f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu'
|
||||||
|
print(f'Using {DEVICE}')
|
||||||
|
else:
|
||||||
|
print('Device must be a zero or positive integer, but got',args.device)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if isinstance(args.seed, int) and args.seed >= 0:
|
||||||
|
seed = set_seed(args.seed)
|
||||||
|
else:
|
||||||
|
print('Seed must be a zero or positive integer, but got',args.seed)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
# dataset_splits = make_splits('config/dataset_splits_dev.json')
|
||||||
|
dataset_splits = make_splits('config/dataset_splits_new.json')
|
||||||
|
|
||||||
|
if args.use_dialogue=='Yes':
|
||||||
|
d_flag = True
|
||||||
|
elif args.use_dialogue=='No':
|
||||||
|
d_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.use_dialogue_moves=='Yes':
|
||||||
|
d_move_flag = True
|
||||||
|
elif args.use_dialogue_moves=='No':
|
||||||
|
d_move_flag = False
|
||||||
|
else:
|
||||||
|
print('Use dialogue must be in [Yes, No], but got',args.use_dialogue)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.experiment in list(range(9)):
|
||||||
|
print('Experiment must be in',list(range(9)),', but got',args.experiment)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if not args.intermediate in list(range(32)):
|
||||||
|
print('Intermediate must be in',list(range(32)),', but got',args.intermediate)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
|
||||||
|
if args.seq_model=='GRU':
|
||||||
|
seq_model = 0
|
||||||
|
elif args.seq_model=='LSTM':
|
||||||
|
seq_model = 1
|
||||||
|
elif args.seq_model=='Transformer':
|
||||||
|
seq_model = 2
|
||||||
|
else:
|
||||||
|
print('The sequence model must be in [GRU, LSTM, Transformer], but got', args.seq_model)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
if args.plans=='Yes':
|
||||||
|
global_plan = (args.pov=='Third') or ((args.pov=='None') and (args.experiment in list(range(3))))
|
||||||
|
player_plan = (args.pov=='First') or ((args.pov=='None') and (args.experiment in list(range(3,9))))
|
||||||
|
elif args.plans=='No' or args.plans is None:
|
||||||
|
global_plan = False
|
||||||
|
player_plan = False
|
||||||
|
else:
|
||||||
|
print('Use Plan must be in [Yes, No], but got',args.plan)
|
||||||
|
exit()
|
||||||
|
print('global_plan', global_plan, 'player_plan', player_plan)
|
||||||
|
|
||||||
|
if args.pov=='None':
|
||||||
|
val = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
val += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
val = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
val = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
val += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['validation'])]
|
||||||
|
train += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['training'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
model = Model(seq_model,DEVICE).to(DEVICE)
|
||||||
|
model.apply(init_weights)
|
||||||
|
|
||||||
|
print(model)
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
learning_rate = 1e-5
|
||||||
|
weight_decay=1e-4
|
||||||
|
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
|
||||||
|
criterion = PlanLoss()
|
||||||
|
|
||||||
|
print(str(criterion), str(optimizer))
|
||||||
|
|
||||||
|
num_epochs = 200#1#
|
||||||
|
min_acc_loss = 1e6
|
||||||
|
max_f1 = 0
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
wait_epoch = 15
|
||||||
|
max_fails = 5
|
||||||
|
|
||||||
|
if args.model_path is not None:
|
||||||
|
print(f'Loading {args.model_path}')
|
||||||
|
model.load_state_dict(torch.load(args.model_path))
|
||||||
|
model.eval()
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
acc, acc_loss0, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
|
||||||
|
if np.mean([acc_loss,acc_loss0]) < min_acc_loss:
|
||||||
|
min_acc_loss = np.mean([acc_loss,acc_loss0])
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print('Training model from scratch', flush=True)
|
||||||
|
|
||||||
|
for epoch in range(num_epochs):
|
||||||
|
print(f'{os.getpid():6d} {epoch+1:4d},',end=' ',flush=True)
|
||||||
|
shuffle(train)
|
||||||
|
model.train()
|
||||||
|
do_split(model,train,args.experiment,criterion,optimizer=optimizer,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
do_split(model,train,args.experiment,criterion,optimizer=optimizer,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
model.eval()
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
acc, acc_loss0, data, _, f1, iou = do_split(model,val,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
|
||||||
|
if np.mean([acc_loss,acc_loss0]) < min_acc_loss:
|
||||||
|
min_acc_loss = np.mean([acc_loss,acc_loss0])
|
||||||
|
epochs_since_improvement = 0
|
||||||
|
print('^')
|
||||||
|
torch.save(model.cpu().state_dict(), args.save_path)
|
||||||
|
model = model.to(DEVICE)
|
||||||
|
else:
|
||||||
|
epochs_since_improvement += 1
|
||||||
|
print()
|
||||||
|
|
||||||
|
if epoch > wait_epoch and epochs_since_improvement > max_fails:
|
||||||
|
break
|
||||||
|
print()
|
||||||
|
print('Test')
|
||||||
|
model.load_state_dict(torch.load(args.save_path))
|
||||||
|
|
||||||
|
val = None
|
||||||
|
train = None
|
||||||
|
if args.pov=='None':
|
||||||
|
test = [GameParser(f,d_flag,0,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
if args.experiment > 2:
|
||||||
|
test += [GameParser(f,d_flag,4,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='Third':
|
||||||
|
test = [GameParser(f,d_flag,3,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
elif args.pov=='First':
|
||||||
|
test = [GameParser(f,d_flag,1,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
test += [GameParser(f,d_flag,2,args.intermediate,d_move_flag) for f in tqdm(dataset_splits['test'])]
|
||||||
|
else:
|
||||||
|
print('POV must be in [None, First, Third], but got', args.pov)
|
||||||
|
model.eval()
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,test,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=True, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
acc, acc_loss, data, _, f1, iou = do_split(model,test,args.experiment,criterion,global_plan=global_plan, player_plan=player_plan, incremental=False, device=DEVICE, intermediate=args.intermediate)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(data)
|
||||||
|
print()
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
parser = argparse.ArgumentParser(description='Process some integers.')
|
||||||
|
parser.add_argument('--pov', type=str,
|
||||||
|
help='point of view [None, First, Third]')
|
||||||
|
parser.add_argument('--use_dialogue', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--use_dialogue_moves', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--plans', type=str,
|
||||||
|
help='Use dialogue [Yes, No]')
|
||||||
|
parser.add_argument('--seq_model', type=str,
|
||||||
|
help='point of view [GRU, LSTM, Transformer]')
|
||||||
|
parser.add_argument('--experiment', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--intermediate', type=int,
|
||||||
|
help='point of view [0:Global, 1:Partner, 2:GlobalDif, 3:PartnerDif]')
|
||||||
|
parser.add_argument('--save_path', type=str,
|
||||||
|
help='path where to save model')
|
||||||
|
parser.add_argument('--seed', type=int,
|
||||||
|
help='Selet random seed by index [0, 1, 2, ...]. 0 -> random seed set to 0. n>0 -> random seed '
|
||||||
|
'set to n\'th random number with original seed set to 0')
|
||||||
|
parser.add_argument('--device', type=int, default=0,
|
||||||
|
help='select cuda device number')
|
||||||
|
parser.add_argument('--model_path', type=str, default=None,
|
||||||
|
help='path to the pretrained model to be loaded')
|
||||||
|
|
||||||
|
main(parser.parse_args())
|
42
run_plan_predictor.sh
Normal file
42
run_plan_predictor.sh
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
echo $$
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
DMOVE=$2
|
||||||
|
DLG=$3
|
||||||
|
|
||||||
|
echo $$ $CUDA_DEVICE $DMOVE $DLG
|
||||||
|
|
||||||
|
FOLDER="models/baselines"
|
||||||
|
# FOLDER="models/test_loss"
|
||||||
|
# FOLDER="models/incremental_pretrained_2_attn_new"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
for SEED in 42; do #1 2 3 4 5; do #0; do #
|
||||||
|
for MODEL in LSTM; do # LSTM; do # Transformer; do #
|
||||||
|
for EXP in 3; do # 2 3; do
|
||||||
|
# for DMOVE in "No" "Yes"; do
|
||||||
|
# for DLG in "No" "Yes"; do
|
||||||
|
for INT in 0 1 2 3 4 5 6 7; do
|
||||||
|
FILE_NAME="plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int${INT}_seed_${SEED}"
|
||||||
|
COMM="plan_predictor.py"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DMOVE}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --pov=First"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --intermediate=${INT}"
|
||||||
|
if [ $INT -gt 0 ]; then
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int0_seed_${SEED}.torch"
|
||||||
|
fi
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T)" python3 ${COMM} > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
done
|
||||||
|
# done
|
||||||
|
# done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo 'Done!'
|
47
run_plan_predictor_graphs.sh
Normal file
47
run_plan_predictor_graphs.sh
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
echo $$
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
DMOVE=$2
|
||||||
|
DLG=$3
|
||||||
|
LR=$4
|
||||||
|
|
||||||
|
echo $$ $CUDA_DEVICE $DMOVE $DLG
|
||||||
|
|
||||||
|
# FOLDER="models/incremental_pretrained_2_new"
|
||||||
|
# FOLDER="models/incremental_pretrained_2_new_fullsampling_ln"
|
||||||
|
# FOLDER="models/exp2"
|
||||||
|
# FOLDER="models/exp3"
|
||||||
|
FOLDER="models/exp3_test"
|
||||||
|
# FOLDER="models/exp2_test"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
for SEED in 7 42 123; do #1 2 3 4 5; do #0; do #
|
||||||
|
for MODEL in Transformer; do # LSTM; do # Transformer; do #
|
||||||
|
for EXP in 3; do # 2 3; do
|
||||||
|
# for DMOVE in "No" "Yes"; do
|
||||||
|
# for DLG in "No" "Yes"; do
|
||||||
|
for INT in 0 1 2 3 4 5 6 7; do
|
||||||
|
FILE_NAME="plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int${INT}_seed_${SEED}"
|
||||||
|
COMM="plan_predictor_graphs.py"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DMOVE}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --pov=First"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --lr=${LR}"
|
||||||
|
COMM=$COMM" --intermediate=${INT}"
|
||||||
|
if [ $INT -gt 0 ]; then
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int0_seed_${SEED}.torch"
|
||||||
|
fi
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T)" python3 ${COMM} > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
done
|
||||||
|
# done
|
||||||
|
# done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo 'Done!'
|
44
run_plan_predictor_graphs_int0.sh
Normal file
44
run_plan_predictor_graphs_int0.sh
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
echo Traning on int0!
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
DMOVE=$2
|
||||||
|
DLG=$3
|
||||||
|
LR=$4
|
||||||
|
|
||||||
|
echo $$ $CUDA_DEVICE $DMOVE $DLG
|
||||||
|
|
||||||
|
# FOLDER="models/incremental_pretrained_2_new"
|
||||||
|
# FOLDER="models/incremental_pretrained_2_new_fullsampling_ln"
|
||||||
|
# FOLDER="models/exp2"
|
||||||
|
# FOLDER="models/exp3"
|
||||||
|
FOLDER="models/exp2_int0exp2"
|
||||||
|
# FOLDER="models/exp2_test"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
for SEED in 7 42 123; do
|
||||||
|
for MODEL in Transformer; do
|
||||||
|
for EXP in 2; do # 2 3; do
|
||||||
|
for INT in 0; do
|
||||||
|
FILE_NAME="plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int${INT}_seed_${SEED}"
|
||||||
|
COMM="plan_predictor_graphs.py"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DMOVE}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --pov=First"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --lr=${LR}"
|
||||||
|
COMM=$COMM" --use_int0_instead_of_intermediate"
|
||||||
|
COMM=$COMM" --intermediate=${INT}"
|
||||||
|
if [ $INT -gt 0 ]; then
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int0_seed_${SEED}.torch"
|
||||||
|
fi
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T)" python3 ${COMM} > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo 'Done!'
|
36
run_plan_predictor_graphs_oracle.sh
Normal file
36
run_plan_predictor_graphs_oracle.sh
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
echo $$
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
DMOVE=$2
|
||||||
|
DLG=$3
|
||||||
|
|
||||||
|
echo $$ $CUDA_DEVICE $DMOVE $DLG
|
||||||
|
|
||||||
|
FOLDER="models/exp3_oracle"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
for SEED in 7 42 123; do
|
||||||
|
for MODEL in Transformer; do
|
||||||
|
for EXP in 3; do
|
||||||
|
for INT in 0 1 2 3 4 5 6 7; do
|
||||||
|
FILE_NAME="plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int${INT}_seed_${SEED}"
|
||||||
|
COMM="plan_predictor_graphs_oracle.py"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DMOVE}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --pov=First"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --intermediate=${INT}"
|
||||||
|
if [ $INT -gt 0 ]; then
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int0_seed_${SEED}.torch"
|
||||||
|
fi
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T)" python3 ${COMM} > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo 'Done!'
|
42
run_plan_predictor_graphs_sigmoid.sh
Normal file
42
run_plan_predictor_graphs_sigmoid.sh
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
echo $$
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
DMOVE=$2
|
||||||
|
DLG=$3
|
||||||
|
|
||||||
|
echo $$ $CUDA_DEVICE $DMOVE $DLG
|
||||||
|
|
||||||
|
# FOLDER="models/incremental_pretrained_2"
|
||||||
|
# FOLDER="models/tests"
|
||||||
|
FOLDER="models/incremental_pretrained_2_sigmoid"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
for SEED in 7 42 73; do #1 2 3 4 5; do #0; do #
|
||||||
|
for MODEL in Transformer; do # LSTM; do # Transformer; do #
|
||||||
|
for EXP in 2; do # 2 3; do
|
||||||
|
# for DMOVE in "No" "Yes"; do
|
||||||
|
# for DLG in "No" "Yes"; do
|
||||||
|
for INT in 0 7; do
|
||||||
|
FILE_NAME="plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int${INT}_seed_${SEED}"
|
||||||
|
COMM="plan_predictor_graphs_sigmoid.py"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DMOVE}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --pov=First"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --intermediate=${INT}"
|
||||||
|
if [ $INT -gt 0 ]; then
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int0_seed_${SEED}.torch"
|
||||||
|
fi
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T)" python3 ${COMM} > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
done
|
||||||
|
# done
|
||||||
|
# done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo 'Done!'
|
40
run_plan_predictor_oracle.sh
Normal file
40
run_plan_predictor_oracle.sh
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
echo $$
|
||||||
|
CUDA_DEVICE=$1
|
||||||
|
DMOVE=$2
|
||||||
|
DLG=$3
|
||||||
|
|
||||||
|
echo $$ $CUDA_DEVICE $DMOVE $DLG
|
||||||
|
|
||||||
|
FOLDER="models/baselines_oracle"
|
||||||
|
mkdir -p $FOLDER
|
||||||
|
for SEED in 123; do #1 2 3 4 5; do #0; do #
|
||||||
|
for MODEL in LSTM; do # LSTM; do # Transformer; do #
|
||||||
|
for EXP in 2; do # 2 3; do
|
||||||
|
# for DMOVE in "No" "Yes"; do
|
||||||
|
# for DLG in "No" "Yes"; do
|
||||||
|
for INT in 0 1 2 3 4 5 6 7; do
|
||||||
|
FILE_NAME="plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int${INT}_seed_${SEED}"
|
||||||
|
COMM="plan_predictor_oracle.py"
|
||||||
|
COMM=$COMM" --seed=${SEED}"
|
||||||
|
COMM=$COMM" --device=${CUDA_DEVICE}"
|
||||||
|
COMM=$COMM" --use_dialogue_moves=${DMOVE}"
|
||||||
|
COMM=$COMM" --use_dialogue=${DLG}"
|
||||||
|
COMM=$COMM" --experiment=${EXP}"
|
||||||
|
COMM=$COMM" --seq_model=${MODEL}"
|
||||||
|
COMM=$COMM" --pov=First"
|
||||||
|
COMM=$COMM" --plan=Yes"
|
||||||
|
COMM=$COMM" --intermediate=${INT}"
|
||||||
|
if [ $INT -gt 0 ]; then
|
||||||
|
COMM=$COMM" --model_path=${FOLDER}/plan_exp${EXP}_${MODEL}_dlg_${DLG}_move_${DMOVE}_int0_seed_${SEED}.torch"
|
||||||
|
fi
|
||||||
|
COMM=$COMM" --save_path=${FOLDER}/${FILE_NAME}.torch"
|
||||||
|
echo $(date +%F\ %T)" python3 ${COMM} > ${FOLDER}/${FILE_NAME}.log"
|
||||||
|
nice -n 5 python3 $COMM > ${FOLDER}/${FILE_NAME}.log
|
||||||
|
done
|
||||||
|
# done
|
||||||
|
# done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
done
|
||||||
|
|
||||||
|
echo 'Done!'
|
BIN
src/.DS_Store
vendored
Normal file
BIN
src/.DS_Store
vendored
Normal file
Binary file not shown.
0
src/__init__.py
Normal file
0
src/__init__.py
Normal file
BIN
src/data/.DS_Store
vendored
Normal file
BIN
src/data/.DS_Store
vendored
Normal file
Binary file not shown.
0
src/data/__init__.py
Normal file
0
src/data/__init__.py
Normal file
761
src/data/game_parser.py
Executable file
761
src/data/game_parser.py
Executable file
|
@ -0,0 +1,761 @@
|
||||||
|
from email.mime import base
|
||||||
|
from glob import glob
|
||||||
|
import os, string, json, pickle
|
||||||
|
import torch, random, numpy as np
|
||||||
|
from transformers import BertTokenizer, BertModel
|
||||||
|
import cv2
|
||||||
|
import imageio
|
||||||
|
from src.data.action_extractor import proc_action
|
||||||
|
|
||||||
|
|
||||||
|
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
# def set_seed(seed_idx):
|
||||||
|
# seed = 0
|
||||||
|
# random.seed(0)
|
||||||
|
# for _ in range(seed_idx):
|
||||||
|
# seed = random.random()
|
||||||
|
# random.seed(seed)
|
||||||
|
# torch.manual_seed(seed)
|
||||||
|
# print('Random seed set to', seed)
|
||||||
|
# return seed
|
||||||
|
|
||||||
|
def set_seed(seed):
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
print('Random seed set to', seed)
|
||||||
|
return seed
|
||||||
|
|
||||||
|
def make_splits(split_file = 'config/dataset_splits.json'):
|
||||||
|
if not os.path.isfile(split_file):
|
||||||
|
dirs = sorted(glob('data/saved_logs/*') + glob('data/main_logs/*'))
|
||||||
|
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
|
||||||
|
|
||||||
|
test = games[0::5]
|
||||||
|
val = games[1::5]
|
||||||
|
train = games[2::5]+games[3::5]+games[4::5]
|
||||||
|
|
||||||
|
dataset_splits = {'test' : [g.game_path for g in test], 'validation' : [g.game_path for g in val], 'training' : [g.game_path for g in train]}
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits_old.json','w'), indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
dirs = sorted(glob('data/new_logs/*'))
|
||||||
|
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
|
||||||
|
|
||||||
|
test = games[0::5]
|
||||||
|
val = games[1::5]
|
||||||
|
train = games[2::5]+games[3::5]+games[4::5]
|
||||||
|
|
||||||
|
dataset_splits['test'] += [g.game_path for g in test]
|
||||||
|
dataset_splits['validation'] += [g.game_path for g in val]
|
||||||
|
dataset_splits['training'] += [g.game_path for g in train]
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits_new.json','w'), indent=4)
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits.json','w'), indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
dataset_splits['test'] = dataset_splits['test'][:2]
|
||||||
|
dataset_splits['validation'] = dataset_splits['validation'][:2]
|
||||||
|
dataset_splits['training'] = dataset_splits['training'][:2]
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits_dev.json','w'), indent=4)
|
||||||
|
|
||||||
|
dataset_splits = json.load(open(split_file))
|
||||||
|
|
||||||
|
return dataset_splits
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class GameParser:
|
||||||
|
tokenizer = None
|
||||||
|
model = None
|
||||||
|
def __init__(self, game_path, load_dialogue=True, pov=0, intermediate=0, use_dialogue_moves=False):
|
||||||
|
# print(game_path,end = ' ')
|
||||||
|
self.load_dialogue = load_dialogue
|
||||||
|
if pov not in (0,1,2,3,4):
|
||||||
|
print('Point of view must be in (0,1,2,3,4), but got ', pov)
|
||||||
|
exit()
|
||||||
|
self.pov = pov
|
||||||
|
self.use_dialogue_moves = use_dialogue_moves
|
||||||
|
self.load_player1 = pov==1
|
||||||
|
self.load_player2 = pov==2
|
||||||
|
self.load_third_person = pov==3
|
||||||
|
self.game_path = game_path
|
||||||
|
# print(game_path)
|
||||||
|
self.dialogue_file = glob(os.path.join(game_path,'mcc*log'))[0]
|
||||||
|
self.questions_file = glob(os.path.join(game_path,'web*log'))[0]
|
||||||
|
self.plan_file = glob(os.path.join(game_path,'plan*json'))[0]
|
||||||
|
self.plan = json.load(open(self.plan_file))
|
||||||
|
self.img_w = 96
|
||||||
|
self.img_h = 96
|
||||||
|
self.intermediate = intermediate
|
||||||
|
|
||||||
|
self.flip_video = False
|
||||||
|
for l in open(self.dialogue_file):
|
||||||
|
if 'HAS JOINED' in l:
|
||||||
|
player_name = l.strip().split()[1]
|
||||||
|
self.flip_video = player_name[-1] == '2'
|
||||||
|
break
|
||||||
|
|
||||||
|
if not os.path.isfile("config/materials.json") or \
|
||||||
|
not os.path.isfile("config/mines.json") or \
|
||||||
|
not os.path.isfile("config/tools.json"):
|
||||||
|
plan_files = sorted(glob('data/*_logs/*/plan*.json'))
|
||||||
|
materials = []
|
||||||
|
tools = []
|
||||||
|
mines = []
|
||||||
|
for plan_file in plan_files:
|
||||||
|
plan = json.load(open(plan_file))
|
||||||
|
materials += plan['materials']
|
||||||
|
tools += plan['tools']
|
||||||
|
mines += plan['mines']
|
||||||
|
materials = sorted(list(set(materials)))
|
||||||
|
tools = sorted(list(set(tools)))
|
||||||
|
mines = sorted(list(set(mines)))
|
||||||
|
json.dump(materials, open('config/materials.json','w'), indent=4)
|
||||||
|
json.dump(mines, open('config/mines.json','w'), indent=4)
|
||||||
|
json.dump(tools, open('config/tools.json','w'), indent=4)
|
||||||
|
|
||||||
|
materials = json.load(open('config/materials.json'))
|
||||||
|
mines = json.load(open('config/mines.json'))
|
||||||
|
tools = json.load(open('config/tools.json'))
|
||||||
|
|
||||||
|
self.materials_dict = {x:i+1 for i,x in enumerate(materials)}
|
||||||
|
self.mines_dict = {x:i+1 for i,x in enumerate(mines)}
|
||||||
|
self.tools_dict = {x:i+1 for i,x in enumerate(tools)}
|
||||||
|
|
||||||
|
self.__load_dialogue_act_labels()
|
||||||
|
self.__load_dialogue_move_labels()
|
||||||
|
self.__parse_dialogue()
|
||||||
|
self.__parse_questions()
|
||||||
|
self.__parse_start_end()
|
||||||
|
self.__parse_question_pairs()
|
||||||
|
self.__load_videos()
|
||||||
|
self.__assign_dialogue_act_labels()
|
||||||
|
self.__assign_dialogue_move_labels()
|
||||||
|
self.__load_replay_data()
|
||||||
|
self.__load_intermediate()
|
||||||
|
|
||||||
|
# print(len(self.materials_dict))
|
||||||
|
|
||||||
|
self.global_plan = []
|
||||||
|
self.global_plan_mat = np.zeros((21,21))
|
||||||
|
mine_counter = 0
|
||||||
|
for n,v in zip(self.plan['materials'],self.plan['full']):
|
||||||
|
if v['make']:
|
||||||
|
mine = 0
|
||||||
|
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
|
||||||
|
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
|
||||||
|
self.global_plan_mat[self.materials_dict[n]-1][m1-1] = 1
|
||||||
|
self.global_plan_mat[self.materials_dict[n]-1][m2-1] = 1
|
||||||
|
else:
|
||||||
|
mine = self.mines_dict[self.plan['mines'][mine_counter]]
|
||||||
|
mine_counter += 1
|
||||||
|
m1 = 0
|
||||||
|
m2 = 0
|
||||||
|
mine = onehot(mine, len(self.mines_dict))
|
||||||
|
m1 = onehot(m1,len(self.materials_dict))
|
||||||
|
m2 = onehot(m2,len(self.materials_dict))
|
||||||
|
mat = onehot(self.materials_dict[n],len(self.materials_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]],len(self.tools_dict))
|
||||||
|
step = np.concatenate((mat,m1,m2,mine,t))
|
||||||
|
self.global_plan.append(step)
|
||||||
|
|
||||||
|
self.player1_plan = []
|
||||||
|
self.player1_plan_mat = np.zeros((21,21))
|
||||||
|
mine_counter = 0
|
||||||
|
for n,v in zip(self.plan['materials'],self.plan['player1']):
|
||||||
|
if v['make']:
|
||||||
|
mine = 0
|
||||||
|
if v['make'][0][0] < 0:
|
||||||
|
m1 = 0
|
||||||
|
m2 = 0
|
||||||
|
else:
|
||||||
|
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
|
||||||
|
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
|
||||||
|
self.player1_plan_mat[self.materials_dict[n]-1][m1-1] = 1
|
||||||
|
self.player1_plan_mat[self.materials_dict[n]-1][m2-1] = 1
|
||||||
|
else:
|
||||||
|
mine = self.mines_dict[self.plan['mines'][mine_counter]]
|
||||||
|
mine_counter += 1
|
||||||
|
m1 = 0
|
||||||
|
m2 = 0
|
||||||
|
mine = onehot(mine, len(self.mines_dict))
|
||||||
|
m1 = onehot(m1,len(self.materials_dict))
|
||||||
|
m2 = onehot(m2,len(self.materials_dict))
|
||||||
|
mat = onehot(self.materials_dict[n],len(self.materials_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]],len(self.tools_dict))
|
||||||
|
step = np.concatenate((mat,m1,m2,mine,t))
|
||||||
|
self.player1_plan.append(step)
|
||||||
|
|
||||||
|
self.player2_plan = []
|
||||||
|
self.player2_plan_mat = np.zeros((21,21))
|
||||||
|
mine_counter = 0
|
||||||
|
for n,v in zip(self.plan['materials'],self.plan['player2']):
|
||||||
|
if v['make']:
|
||||||
|
mine = 0
|
||||||
|
if v['make'][0][0] < 0:
|
||||||
|
m1 = 0
|
||||||
|
m2 = 0
|
||||||
|
else:
|
||||||
|
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
|
||||||
|
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
|
||||||
|
self.player2_plan_mat[self.materials_dict[n]-1][m1-1] = 1
|
||||||
|
self.player2_plan_mat[self.materials_dict[n]-1][m2-1] = 1
|
||||||
|
else:
|
||||||
|
mine = self.mines_dict[self.plan['mines'][mine_counter]]
|
||||||
|
mine_counter += 1
|
||||||
|
m1 = 0
|
||||||
|
m2 = 0
|
||||||
|
mine = onehot(mine, len(self.mines_dict))
|
||||||
|
m1 = onehot(m1,len(self.materials_dict))
|
||||||
|
m2 = onehot(m2,len(self.materials_dict))
|
||||||
|
mat = onehot(self.materials_dict[n],len(self.materials_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]],len(self.tools_dict))
|
||||||
|
step = np.concatenate((mat,m1,m2,mine,t))
|
||||||
|
self.player2_plan.append(step)
|
||||||
|
# print(self.global_plan_mat.reshape(-1))
|
||||||
|
# print(self.player1_plan_mat.reshape(-1))
|
||||||
|
# print(self.player2_plan_mat.reshape(-1))
|
||||||
|
# for x in zip(self.global_plan_mat.reshape(-1),self.player1_plan_mat.reshape(-1),self.player2_plan_mat.reshape(-1)):
|
||||||
|
# if sum(x) > 0:
|
||||||
|
# print(x)
|
||||||
|
# exit()
|
||||||
|
if self.load_player1:
|
||||||
|
self.plan_repr = self.player1_plan_mat
|
||||||
|
self.partner_plan = self.player2_plan_mat
|
||||||
|
elif self.load_player2:
|
||||||
|
self.plan_repr = self.player2_plan_mat
|
||||||
|
self.partner_plan = self.player1_plan_mat
|
||||||
|
else:
|
||||||
|
self.plan_repr = self.global_plan_mat
|
||||||
|
self.partner_plan = self.global_plan_mat
|
||||||
|
self.global_diff_plan_mat = self.global_plan_mat - self.plan_repr
|
||||||
|
self.partner_diff_plan_mat = self.global_plan_mat - self.partner_plan
|
||||||
|
|
||||||
|
self.__iter_ts = self.start_ts
|
||||||
|
|
||||||
|
# self.action_labels = sorted([t for a in self.actions for t in a if t.PacketData in ['BlockChangeData']], key=lambda x: x.TickIndex)
|
||||||
|
self.action_labels = None
|
||||||
|
# for tick in ticks:
|
||||||
|
# print(int(tick.TickIndex/30), self.plan['materials'].index( tick.items[0]), int(tick.Name[-1]))
|
||||||
|
# print(self.start_ts, self.end_ts, self.start_ts - self.end_ts, int(ticks[-1].TickIndex/30) if ticks else 0,self.action_file)
|
||||||
|
# exit()
|
||||||
|
self.materials = sorted(self.plan['materials'])
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.end_ts - self.start_ts
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.__iter_ts < self.end_ts:
|
||||||
|
|
||||||
|
if self.load_dialogue:
|
||||||
|
d = [x for x in self.dialogue_events if x[0] == self.__iter_ts]
|
||||||
|
l = [x for x in self.dialogue_act_labels if x[0] == self.__iter_ts]
|
||||||
|
d = d if d else None
|
||||||
|
l = l if l else None
|
||||||
|
else:
|
||||||
|
d = None
|
||||||
|
l = None
|
||||||
|
|
||||||
|
if self.use_dialogue_moves:
|
||||||
|
m = [x for x in self.dialogue_move_labels if x[0] == self.__iter_ts]
|
||||||
|
m = m if m else None
|
||||||
|
else:
|
||||||
|
m = None
|
||||||
|
|
||||||
|
if self.action_labels:
|
||||||
|
a = [x for x in self.action_labels if (x.TickIndex//30 + self.start_ts) >= self.__iter_ts]
|
||||||
|
if a:
|
||||||
|
try:
|
||||||
|
while not a[0].items:
|
||||||
|
a = a[1:]
|
||||||
|
al = self.materials.index(a[0].items[0]) if a else 0
|
||||||
|
except Exception:
|
||||||
|
print(a)
|
||||||
|
print(a[0])
|
||||||
|
print(a[0].items)
|
||||||
|
print(a[0].items[0])
|
||||||
|
exit()
|
||||||
|
at = a[0].TickIndex//30 + self.start_ts
|
||||||
|
an = int(a[0].Name[-1])
|
||||||
|
a = [(at,al,an)]
|
||||||
|
else:
|
||||||
|
a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
|
||||||
|
a = None
|
||||||
|
else:
|
||||||
|
if self.end_ts - self.__iter_ts < 10:
|
||||||
|
# a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
|
||||||
|
a = None
|
||||||
|
else:
|
||||||
|
a = None
|
||||||
|
# if not self.__iter_ts % 30 == 0:
|
||||||
|
# a= None
|
||||||
|
if not a is None:
|
||||||
|
if not a[0][0] == self.__iter_ts:
|
||||||
|
a = None
|
||||||
|
|
||||||
|
# q = [x for x in self.question_pairs if (x[0][0] < self.__iter_ts) and (x[0][1] > self.__iter_ts)]
|
||||||
|
q = [x for x in self.question_pairs if (x[0][1] == self.__iter_ts)]
|
||||||
|
q = q[0] if q else None
|
||||||
|
frame_idx = self.__iter_ts - self.start_ts
|
||||||
|
if self.load_third_person:
|
||||||
|
frames = self.third_pers_frames
|
||||||
|
elif self.load_player1:
|
||||||
|
frames = self.player1_pov_frames
|
||||||
|
elif self.load_player2:
|
||||||
|
frames = self.player2_pov_frames
|
||||||
|
else:
|
||||||
|
frames = np.array([0])
|
||||||
|
if len(frames) == 1:
|
||||||
|
f = np.zeros((self.img_h,self.img_w,3))
|
||||||
|
else:
|
||||||
|
if frame_idx < frames.shape[0]:
|
||||||
|
f = frames[frame_idx]
|
||||||
|
else:
|
||||||
|
f = np.zeros((self.img_h,self.img_w,3))
|
||||||
|
if self.do_upperbound:
|
||||||
|
if not q is None:
|
||||||
|
qnum = 0
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(q[0][2],2),
|
||||||
|
onehot(q[0][3],2),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
ToM6 = base_rep if self.ToM6 is not None else np.zeros(1024)
|
||||||
|
qnum = 1
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(q[0][2],2),
|
||||||
|
onehot(q[0][3],2),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
ToM7 = base_rep if self.ToM7 is not None else np.zeros(1024)
|
||||||
|
qnum = 2
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(q[0][2],2),
|
||||||
|
onehot(q[0][3],2),
|
||||||
|
onehot(q[0][4][qnum]+1,2),
|
||||||
|
onehot(q[0][4][qnum]+1,2),
|
||||||
|
onehot(self.materials_dict[q[1][0][qnum]] if q[1][0][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1),
|
||||||
|
onehot(self.materials_dict[q[1][1][qnum]] if q[1][1][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1)
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
ToM8 = base_rep if self.ToM8 is not None else np.zeros(1024)
|
||||||
|
else:
|
||||||
|
ToM6 = np.zeros(1024)
|
||||||
|
ToM7 = np.zeros(1024)
|
||||||
|
ToM8 = np.zeros(1024)
|
||||||
|
if not l is None:
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(l[0][1],2),
|
||||||
|
onehot(l[0][2],len(self.dialogue_act_labels_dict))
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
DAct = base_rep if self.DAct is not None else np.zeros(1024)
|
||||||
|
else:
|
||||||
|
DAct = np.zeros(1024)
|
||||||
|
if not m is None:
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(m[0][1],2),
|
||||||
|
onehot(m[0][2][0],len(self.dialogue_move_labels_dict)),
|
||||||
|
onehot(m[0][2][1],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
onehot(m[0][2][2],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
onehot(m[0][2][3],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
DMove = base_rep if self.DMove is not None else np.zeros(1024)
|
||||||
|
else:
|
||||||
|
DMove = np.zeros(1024)
|
||||||
|
else:
|
||||||
|
ToM6 = self.ToM6[frame_idx] if self.ToM6 is not None else np.zeros(1024)
|
||||||
|
ToM7 = self.ToM7[frame_idx] if self.ToM7 is not None else np.zeros(1024)
|
||||||
|
ToM8 = self.ToM8[frame_idx] if self.ToM8 is not None else np.zeros(1024)
|
||||||
|
DAct = self.DAct[frame_idx] if self.DAct is not None else np.zeros(1024)
|
||||||
|
DMove = self.DAct[frame_idx] if self.DMove is not None else np.zeros(1024)
|
||||||
|
# if not m is None:
|
||||||
|
# base_rep = np.concatenate([
|
||||||
|
# onehot(m[0][1],2),
|
||||||
|
# onehot(m[0][2][0],len(self.dialogue_move_labels_dict)),
|
||||||
|
# onehot(m[0][2][1],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
# onehot(m[0][2][2],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
# onehot(m[0][2][3],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
# ])
|
||||||
|
# base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
# DMove = base_rep if self.DMove is not None else np.zeros(1024)
|
||||||
|
# else:
|
||||||
|
# DMove = np.zeros(1024)
|
||||||
|
intermediate = np.concatenate([ToM6,ToM7,ToM8,DAct,DMove])
|
||||||
|
retval = ((self.__iter_ts,self.pov),d,l,q,f,a,intermediate,m)
|
||||||
|
self.__iter_ts += 1
|
||||||
|
return retval
|
||||||
|
self.__iter_ts = self.start_ts
|
||||||
|
raise StopIteration()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __load_videos(self):
|
||||||
|
d = self.end_ts - self.start_ts
|
||||||
|
|
||||||
|
if self.load_third_person:
|
||||||
|
try:
|
||||||
|
self.third_pers_file = glob(os.path.join(self.game_path,'third*gif'))[0]
|
||||||
|
np_file = self.third_pers_file[:-3]+'npz'
|
||||||
|
if os.path.isfile(np_file):
|
||||||
|
self.third_pers_frames = np.load(np_file)['data']
|
||||||
|
else:
|
||||||
|
frames = imageio.get_reader(self.third_pers_file, '.gif')
|
||||||
|
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
|
||||||
|
if 'main' in self.game_path:
|
||||||
|
self.third_pers_frames = np.array([reshaper(f[95:4*95,250:-249,2::-1]) for f in frames])
|
||||||
|
else:
|
||||||
|
self.third_pers_frames = np.array([reshaper(f[-3*95:,250:-249,2::-1]) for f in frames])
|
||||||
|
print(np_file,end=' ')
|
||||||
|
np.savez_compressed(open(np_file,'wb'), data=self.third_pers_frames)
|
||||||
|
print('saved')
|
||||||
|
except Exception as e:
|
||||||
|
self.third_pers_frames = np.array([0])
|
||||||
|
|
||||||
|
if self.third_pers_frames.shape[0]//d < 10:
|
||||||
|
self.third_pov_frame_rate = 6
|
||||||
|
else:
|
||||||
|
if self.third_pers_frames.shape[0]//d < 20:
|
||||||
|
self.third_pov_frame_rate = 12
|
||||||
|
else:
|
||||||
|
if self.third_pers_frames.shape[0]//d < 45:
|
||||||
|
self.third_pov_frame_rate = 30
|
||||||
|
else:
|
||||||
|
self.third_pov_frame_rate = 60
|
||||||
|
self.third_pers_frames = self.third_pers_frames[::self.third_pov_frame_rate]
|
||||||
|
else:
|
||||||
|
self.third_pers_frames = np.array([0])
|
||||||
|
|
||||||
|
if self.load_player1:
|
||||||
|
try:
|
||||||
|
search_str = 'play2*gif' if self.flip_video else 'play1*gif'
|
||||||
|
self.player1_pov_file = glob(os.path.join(self.game_path,search_str))[0]
|
||||||
|
np_file = self.player1_pov_file[:-3]+'npz'
|
||||||
|
if os.path.isfile(np_file):
|
||||||
|
self.player1_pov_frames = np.load(np_file)['data']
|
||||||
|
else:
|
||||||
|
frames = imageio.get_reader(self.player1_pov_file, '.gif')
|
||||||
|
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
|
||||||
|
self.player1_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
|
||||||
|
print(np_file,end=' ')
|
||||||
|
np.savez_compressed(open(np_file,'wb'), data=self.player1_pov_frames)
|
||||||
|
print('saved')
|
||||||
|
except Exception as e:
|
||||||
|
self.player1_pov_frames = np.array([0])
|
||||||
|
|
||||||
|
if self.player1_pov_frames.shape[0]//d < 10:
|
||||||
|
self.player1_pov_frame_rate = 6
|
||||||
|
else:
|
||||||
|
if self.player1_pov_frames.shape[0]//d < 20:
|
||||||
|
self.player1_pov_frame_rate = 12
|
||||||
|
else:
|
||||||
|
if self.player1_pov_frames.shape[0]//d < 45:
|
||||||
|
self.player1_pov_frame_rate = 30
|
||||||
|
else:
|
||||||
|
self.player1_pov_frame_rate = 60
|
||||||
|
self.player1_pov_frames = self.player1_pov_frames[::self.player1_pov_frame_rate]
|
||||||
|
else:
|
||||||
|
self.player1_pov_frames = np.array([0])
|
||||||
|
|
||||||
|
if self.load_player2:
|
||||||
|
try:
|
||||||
|
search_str = 'play1*gif' if self.flip_video else 'play2*gif'
|
||||||
|
self.player2_pov_file = glob(os.path.join(self.game_path,search_str))[0]
|
||||||
|
np_file = self.player2_pov_file[:-3]+'npz'
|
||||||
|
if os.path.isfile(np_file):
|
||||||
|
self.player2_pov_frames = np.load(np_file)['data']
|
||||||
|
else:
|
||||||
|
frames = imageio.get_reader(self.player2_pov_file, '.gif')
|
||||||
|
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
|
||||||
|
self.player2_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
|
||||||
|
print(np_file,end=' ')
|
||||||
|
np.savez_compressed(open(np_file,'wb'), data=self.player2_pov_frames)
|
||||||
|
print('saved')
|
||||||
|
except Exception as e:
|
||||||
|
self.player2_pov_frames = np.array([0])
|
||||||
|
|
||||||
|
if self.player2_pov_frames.shape[0]//d < 10:
|
||||||
|
self.player2_pov_frame_rate = 6
|
||||||
|
else:
|
||||||
|
if self.player2_pov_frames.shape[0]//d < 20:
|
||||||
|
self.player2_pov_frame_rate = 12
|
||||||
|
else:
|
||||||
|
if self.player2_pov_frames.shape[0]//d < 45:
|
||||||
|
self.player2_pov_frame_rate = 30
|
||||||
|
else:
|
||||||
|
self.player2_pov_frame_rate = 60
|
||||||
|
self.player2_pov_frames = self.player2_pov_frames[::self.player2_pov_frame_rate]
|
||||||
|
else:
|
||||||
|
self.player2_pov_frames = np.array([0])
|
||||||
|
|
||||||
|
def __parse_question_pairs(self):
|
||||||
|
question_dict = {}
|
||||||
|
for q in self.questions:
|
||||||
|
k = q[2][0][1] + q[2][1][1]
|
||||||
|
if not k in question_dict:
|
||||||
|
question_dict[k] = []
|
||||||
|
question_dict[k].append(q)
|
||||||
|
|
||||||
|
self.question_pairs = []
|
||||||
|
for k,v in question_dict.items():
|
||||||
|
if len(v) == 2:
|
||||||
|
if v[0][1]+v[1][1] == 3:
|
||||||
|
self.question_pairs.append(v)
|
||||||
|
else:
|
||||||
|
while len(v) > 1:
|
||||||
|
pair = []
|
||||||
|
pair.append(v.pop(0))
|
||||||
|
pair.append(v.pop(0))
|
||||||
|
while not pair[0][1]+pair[1][1] == 3:
|
||||||
|
if not v:
|
||||||
|
break
|
||||||
|
# print(game_path,pair)
|
||||||
|
pair.append(v.pop(0))
|
||||||
|
pair.pop(0)
|
||||||
|
if not v:
|
||||||
|
break
|
||||||
|
self.question_pairs.append(pair)
|
||||||
|
self.question_pairs = sorted(self.question_pairs, key=lambda x: x[0][0])
|
||||||
|
if self.load_player2 or self.pov==4:
|
||||||
|
self.question_pairs = [sorted(q, key=lambda x: x[1],reverse=True) for q in self.question_pairs]
|
||||||
|
else:
|
||||||
|
self.question_pairs = [sorted(q, key=lambda x: x[1]) for q in self.question_pairs]
|
||||||
|
|
||||||
|
|
||||||
|
self.question_pairs = [((a[0], b[0], a[1], b[1], a[2], b[2]), (a[3], b[3])) for a,b in self.question_pairs]
|
||||||
|
|
||||||
|
def __parse_dialogue(self):
|
||||||
|
self.dialogue_events = []
|
||||||
|
# if not self.load_dialogue:
|
||||||
|
# return
|
||||||
|
save_path = os.path.join(self.game_path,f'dialogue_{self.game_path.split("/")[-1]}.pkl')
|
||||||
|
# print(save_path)
|
||||||
|
# exit()
|
||||||
|
if os.path.isfile(save_path):
|
||||||
|
self.dialogue_events = pickle.load(open( save_path, "rb" ))
|
||||||
|
return
|
||||||
|
for x in open(self.dialogue_file):
|
||||||
|
if '[Async Chat Thread' in x:
|
||||||
|
ts = list(map(int,x.split(' [')[0].strip('[]').split(':')))
|
||||||
|
ts = 3600*ts[0] + 60*ts[1] + ts[2]
|
||||||
|
player, event = x.strip().split('/INFO]: []<sledmcc')[1].split('> ',1)
|
||||||
|
event = event.lower()
|
||||||
|
event = ''.join([x if x in string.ascii_lowercase else f' {x} ' for x in event]).strip()
|
||||||
|
event = event.replace(' ',' ').replace(' ',' ')
|
||||||
|
player = int(player)
|
||||||
|
if GameParser.tokenizer is None:
|
||||||
|
GameParser.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)
|
||||||
|
if self.model is None:
|
||||||
|
GameParser.model = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True).to(DEVICE)
|
||||||
|
encoded_dict = GameParser.tokenizer.encode_plus(
|
||||||
|
event, # Sentence to encode.
|
||||||
|
add_special_tokens=True, # Add '[CLS]' and '[SEP]'
|
||||||
|
return_tensors='pt', # Return pytorch tensors.
|
||||||
|
)
|
||||||
|
token_ids = encoded_dict['input_ids'].to(DEVICE)
|
||||||
|
segment_ids = torch.ones(token_ids.size()).long().to(DEVICE)
|
||||||
|
GameParser.model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = GameParser.model(input_ids=token_ids, token_type_ids=segment_ids)
|
||||||
|
outputs = outputs[1][0].cpu().data.numpy()
|
||||||
|
self.dialogue_events.append((ts,player,event,outputs))
|
||||||
|
pickle.dump(self.dialogue_events, open( save_path, "wb" ))
|
||||||
|
print(f'Saved to {save_path}',flush=True)
|
||||||
|
|
||||||
|
def __parse_questions(self):
|
||||||
|
self.questions = []
|
||||||
|
for x in open(self.questions_file):
|
||||||
|
if x[0] == '#':
|
||||||
|
ts, qs = x.strip().split(' Number of records inserted: 1 # player')
|
||||||
|
# print(ts,qs)
|
||||||
|
|
||||||
|
ts = list(map(int,ts.split(' ')[5].split(':')))
|
||||||
|
ts = 3600*ts[0] + 60*ts[1] + ts[2]
|
||||||
|
|
||||||
|
player = int(qs[0])
|
||||||
|
questions = qs[2:].split(';')
|
||||||
|
answers =[x[7:] for x in questions[3:]]
|
||||||
|
questions = [x[9:].split(' ') for x in questions[:3]]
|
||||||
|
questions[0] = (int(questions[0][0] == 'Have'), questions[0][-3])
|
||||||
|
questions[1] = (int(questions[1][2] == 'know'), questions[1][-1])
|
||||||
|
questions[2] = int(questions[2][1] == 'are')
|
||||||
|
|
||||||
|
self.questions.append((ts,player,questions,answers))
|
||||||
|
def __parse_start_end(self):
|
||||||
|
self.start_ts = [x.strip() for x in open(self.dialogue_file) if 'THEY ARE PLAYER' in x][1]
|
||||||
|
self.start_ts = list(map(int,self.start_ts.split('] [')[0][1:].split(':')))
|
||||||
|
self.start_ts = 3600*self.start_ts[0] + 60*self.start_ts[1] + self.start_ts[2]
|
||||||
|
try:
|
||||||
|
self.start_ts = max(self.start_ts, self.questions[0][0]-75)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self.end_ts = [x.strip() for x in open(self.dialogue_file) if 'Stopping' in x]
|
||||||
|
if self.end_ts:
|
||||||
|
self.end_ts = self.end_ts[0]
|
||||||
|
self.end_ts = list(map(int,self.end_ts.split('] [')[0][1:].split(':')))
|
||||||
|
self.end_ts = 3600*self.end_ts[0] + 60*self.end_ts[1] + self.end_ts[2]
|
||||||
|
else:
|
||||||
|
self.end_ts = self.dialogue_events[-1][0]
|
||||||
|
try:
|
||||||
|
self.end_ts = max(self.end_ts, self.questions[-1][0]) + 1
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __load_dialogue_act_labels(self):
|
||||||
|
file_name = 'config/dialogue_act_labels.json'
|
||||||
|
if not os.path.isfile(file_name):
|
||||||
|
files = sorted(glob('/home/*/MCC/*done.txt'))
|
||||||
|
dialogue_act_dict = {}
|
||||||
|
for file in files:
|
||||||
|
game_str = ''
|
||||||
|
for line in open(file):
|
||||||
|
line = line.strip()
|
||||||
|
if '_logs/' in line:
|
||||||
|
game_str = line
|
||||||
|
else:
|
||||||
|
if line:
|
||||||
|
line = line.split()
|
||||||
|
key = f'{game_str}#{line[0]}'
|
||||||
|
dialogue_act_dict[key] = line[-1]
|
||||||
|
json.dump(dialogue_act_dict,open(file_name,'w'), indent=4)
|
||||||
|
self.dialogue_act_dict = json.load(open(file_name))
|
||||||
|
self.dialogue_act_labels_dict = {l : i for i, l in enumerate(sorted(list(set(self.dialogue_act_dict.values()))))}
|
||||||
|
self.dialogue_act_bias = {l : sum([int(x==l) for x in self.dialogue_act_dict.values()]) for l in self.dialogue_act_labels_dict.keys()}
|
||||||
|
json.dump(self.dialogue_act_labels_dict,open('config/dialogue_act_label_names.json','w'), indent=4)
|
||||||
|
# print(self.dialogue_act_bias)
|
||||||
|
# print(self.dialogue_act_labels_dict)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
def __assign_dialogue_act_labels(self):
|
||||||
|
|
||||||
|
# log_file = glob('/'.join([self.game_path,'mcc*log']))[0][5:]
|
||||||
|
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
|
||||||
|
self.dialogue_act_labels = []
|
||||||
|
for emb in self.dialogue_events:
|
||||||
|
ts = emb[0]
|
||||||
|
h = ts//3600
|
||||||
|
m = (ts%3600)//60
|
||||||
|
s = ts%60
|
||||||
|
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
|
||||||
|
self.dialogue_act_labels.append((emb[0],emb[1],self.dialogue_act_labels_dict[self.dialogue_act_dict[key]]))
|
||||||
|
|
||||||
|
def __load_dialogue_move_labels(self):
|
||||||
|
file_name = "config/dialogue_move_labels.json"
|
||||||
|
dialogue_move_dict = {}
|
||||||
|
if not os.path.isfile(file_name):
|
||||||
|
file_text = ''
|
||||||
|
dialogue_moves = set()
|
||||||
|
for line in open("XXX"):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
if line[0] == '#':
|
||||||
|
continue
|
||||||
|
if line[0] == '[':
|
||||||
|
tag_text = glob(f'data/*/*/mcc_{file_text}.log')[0].split('/',1)[-1]
|
||||||
|
key = f'{tag_text}#{line.split()[0]}'
|
||||||
|
value = line.split()[-1].split('#')
|
||||||
|
if len(value) < 4:
|
||||||
|
value += ['IGNORE']*(4-len(value))
|
||||||
|
dialogue_moves.add(value[0])
|
||||||
|
value = '#'.join(value)
|
||||||
|
dialogue_move_dict[key] = value
|
||||||
|
# print(key,value)
|
||||||
|
# break
|
||||||
|
else:
|
||||||
|
file_text = line
|
||||||
|
# print(line)
|
||||||
|
dialogue_moves = sorted(list(dialogue_moves))
|
||||||
|
# print(dialogue_moves)
|
||||||
|
|
||||||
|
json.dump(dialogue_move_dict,open(file_name,'w'), indent=4)
|
||||||
|
self.dialogue_move_dict = json.load(open(file_name))
|
||||||
|
self.dialogue_move_labels_dict = {l : i for i, l in enumerate(sorted(list(set([lbl.split('#')[0] for lbl in self.dialogue_move_dict.values()]))))}
|
||||||
|
self.dialogue_move_bias = {l : sum([int(x==l) for x in self.dialogue_move_dict.values()]) for l in self.dialogue_move_labels_dict.keys()}
|
||||||
|
json.dump(self.dialogue_move_labels_dict,open('config/dialogue_move_label_names.json','w'), indent=4)
|
||||||
|
|
||||||
|
def __assign_dialogue_move_labels(self):
|
||||||
|
|
||||||
|
# log_file = glob('/'.join([self.game_path,'mcc*log']))[0][5:]
|
||||||
|
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
|
||||||
|
self.dialogue_move_labels = []
|
||||||
|
for emb in self.dialogue_events:
|
||||||
|
ts = emb[0]
|
||||||
|
h = ts//3600
|
||||||
|
m = (ts%3600)//60
|
||||||
|
s = ts%60
|
||||||
|
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
|
||||||
|
move = self.dialogue_move_dict[key].split('#')
|
||||||
|
move[0] = self.dialogue_move_labels_dict[move[0]]
|
||||||
|
for i,m in enumerate(move[1:]):
|
||||||
|
if m == 'IGNORE':
|
||||||
|
move[i+1] = 0
|
||||||
|
elif m in self.materials_dict:
|
||||||
|
move[i+1] = self.materials_dict[m]
|
||||||
|
elif m in self.mines_dict:
|
||||||
|
move[i+1] = self.mines_dict[m] + len(self.materials_dict)
|
||||||
|
elif m in self.tools_dict:
|
||||||
|
move[i+1] = self.tools_dict[m] + len(self.materials_dict) + len(self.mines_dict)
|
||||||
|
else:
|
||||||
|
print(move)
|
||||||
|
exit()
|
||||||
|
# print(move,self.dialogue_move_dict[key],key)
|
||||||
|
# exit()
|
||||||
|
self.dialogue_move_labels.append((emb[0],emb[1],move))
|
||||||
|
|
||||||
|
def __load_replay_data(self):
|
||||||
|
# self.action_file = "data/ReplayData/ActionsData_mcc_" + self.game_path.split('/')[-1]
|
||||||
|
# with open(self.action_file) as f:
|
||||||
|
# data = ' '.join(x.strip() for x in f).split('action')
|
||||||
|
# # preface = data[0]
|
||||||
|
# self.actions = list(map(proc_action, data[1:]))
|
||||||
|
self.actions = None
|
||||||
|
|
||||||
|
def __load_intermediate(self):
|
||||||
|
if self.intermediate > 15:
|
||||||
|
self.do_upperbound = True
|
||||||
|
else:
|
||||||
|
self.do_upperbound = False
|
||||||
|
if self.pov in [1,2]:
|
||||||
|
self.ToM6 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM6*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.ToM7 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM7*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.ToM8 = np.load(glob(f'{self.game_path}/intermediate_baseline_ToM8*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.DAct = np.load(glob(f'{self.game_path}/intermediate_baseline_DAct*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.DMove = None
|
||||||
|
# print(self.ToM6)
|
||||||
|
# print(self.ToM7)
|
||||||
|
# print(self.ToM8)
|
||||||
|
# print(self.DAct)
|
||||||
|
else:
|
||||||
|
self.ToM6 = None
|
||||||
|
self.ToM7 = None
|
||||||
|
self.ToM8 = None
|
||||||
|
self.DAct = None
|
||||||
|
self.DMove = None
|
||||||
|
# exit()
|
||||||
|
|
837
src/data/game_parser_graphs_new.py
Normal file
837
src/data/game_parser_graphs_new.py
Normal file
|
@ -0,0 +1,837 @@
|
||||||
|
from glob import glob
|
||||||
|
import os, string, json, pickle
|
||||||
|
import torch, random, numpy as np
|
||||||
|
from transformers import BertTokenizer, BertModel
|
||||||
|
import cv2
|
||||||
|
import imageio
|
||||||
|
import networkx as nx
|
||||||
|
from torch_geometric.utils.convert import from_networkx
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from torch_geometric.utils import degree
|
||||||
|
|
||||||
|
|
||||||
|
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||||
|
|
||||||
|
# def set_seed(seed_idx):
|
||||||
|
# seed = 0
|
||||||
|
# random.seed(0)
|
||||||
|
# for _ in range(seed_idx):
|
||||||
|
# seed = random.random()
|
||||||
|
# random.seed(seed)
|
||||||
|
# torch.manual_seed(seed)
|
||||||
|
# print('Random seed set to', seed)
|
||||||
|
# return seed
|
||||||
|
|
||||||
|
def set_seed(seed):
|
||||||
|
os.environ['PYTHONHASHSEED'] = str(seed)
|
||||||
|
torch.manual_seed(seed)
|
||||||
|
np.random.seed(seed)
|
||||||
|
random.seed(seed)
|
||||||
|
print('Random seed set to', seed)
|
||||||
|
return seed
|
||||||
|
|
||||||
|
def make_splits(split_file = 'config/dataset_splits.json'):
|
||||||
|
if not os.path.isfile(split_file):
|
||||||
|
dirs = sorted(glob('data/saved_logs/*') + glob('data/main_logs/*'))
|
||||||
|
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
|
||||||
|
|
||||||
|
test = games[0::5]
|
||||||
|
val = games[1::5]
|
||||||
|
train = games[2::5]+games[3::5]+games[4::5]
|
||||||
|
|
||||||
|
dataset_splits = {'test' : [g.game_path for g in test], 'validation' : [g.game_path for g in val], 'training' : [g.game_path for g in train]}
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits_old.json','w'), indent=4)
|
||||||
|
|
||||||
|
dirs = sorted(glob('data/new_logs/*'))
|
||||||
|
games = sorted(list(map(GameParser, dirs)), key=lambda x: len(x.question_pairs), reverse=True)
|
||||||
|
|
||||||
|
test = games[0::5]
|
||||||
|
val = games[1::5]
|
||||||
|
train = games[2::5]+games[3::5]+games[4::5]
|
||||||
|
|
||||||
|
dataset_splits['test'] += [g.game_path for g in test]
|
||||||
|
dataset_splits['validation'] += [g.game_path for g in val]
|
||||||
|
dataset_splits['training'] += [g.game_path for g in train]
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits_new.json','w'), indent=4)
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits.json','w'), indent=4)
|
||||||
|
|
||||||
|
dataset_splits['test'] = dataset_splits['test'][:2]
|
||||||
|
dataset_splits['validation'] = dataset_splits['validation'][:2]
|
||||||
|
dataset_splits['training'] = dataset_splits['training'][:2]
|
||||||
|
json.dump(dataset_splits, open('config/dataset_splits_dev.json','w'), indent=4)
|
||||||
|
|
||||||
|
dataset_splits = json.load(open(split_file))
|
||||||
|
|
||||||
|
return dataset_splits
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class GameParser:
|
||||||
|
tokenizer = None
|
||||||
|
model = None
|
||||||
|
def __init__(self, game_path, load_dialogue=True, pov=0, intermediate=0, use_dialogue_moves=False, load_int0_feats=False):
|
||||||
|
self.load_dialogue = load_dialogue
|
||||||
|
if pov not in (0,1,2,3,4):
|
||||||
|
print('Point of view must be in (0,1,2,3,4), but got ', pov)
|
||||||
|
exit()
|
||||||
|
self.pov = pov
|
||||||
|
self.use_dialogue_moves = use_dialogue_moves
|
||||||
|
self.load_player1 = pov==1
|
||||||
|
self.load_player2 = pov==2
|
||||||
|
self.load_third_person = pov==3
|
||||||
|
self.game_path = game_path
|
||||||
|
self.dialogue_file = glob(os.path.join(game_path,'mcc*log'))[0]
|
||||||
|
self.questions_file = glob(os.path.join(game_path,'web*log'))[0]
|
||||||
|
self.plan_file = glob(os.path.join(game_path,'plan*json'))[0]
|
||||||
|
self.plan = json.load(open(self.plan_file))
|
||||||
|
self.img_w = 96
|
||||||
|
self.img_h = 96
|
||||||
|
self.intermediate = intermediate
|
||||||
|
|
||||||
|
self.flip_video = False
|
||||||
|
for l in open(self.dialogue_file):
|
||||||
|
if 'HAS JOINED' in l:
|
||||||
|
player_name = l.strip().split()[1]
|
||||||
|
self.flip_video = player_name[-1] == '2'
|
||||||
|
break
|
||||||
|
|
||||||
|
if not os.path.isfile("config/materials.json") or \
|
||||||
|
not os.path.isfile("config/mines.json") or \
|
||||||
|
not os.path.isfile("config/tools.json"):
|
||||||
|
plan_files = sorted(glob('data/*_logs/*/plan*.json'))
|
||||||
|
materials = []
|
||||||
|
tools = []
|
||||||
|
mines = []
|
||||||
|
for plan_file in plan_files:
|
||||||
|
plan = json.load(open(plan_file))
|
||||||
|
materials += plan['materials']
|
||||||
|
tools += plan['tools']
|
||||||
|
mines += plan['mines']
|
||||||
|
materials = sorted(list(set(materials)))
|
||||||
|
tools = sorted(list(set(tools)))
|
||||||
|
mines = sorted(list(set(mines)))
|
||||||
|
json.dump(materials, open('config/materials.json','w'), indent=4)
|
||||||
|
json.dump(mines, open('config/mines.json','w'), indent=4)
|
||||||
|
json.dump(tools, open('config/tools.json','w'), indent=4)
|
||||||
|
|
||||||
|
materials = json.load(open('config/materials.json'))
|
||||||
|
mines = json.load(open('config/mines.json'))
|
||||||
|
tools = json.load(open('config/tools.json'))
|
||||||
|
|
||||||
|
self.materials_dict = {x:i+1 for i,x in enumerate(materials)}
|
||||||
|
self.mines_dict = {x:i+1 for i,x in enumerate(mines)}
|
||||||
|
self.tools_dict = {x:i+1 for i,x in enumerate(tools)}
|
||||||
|
|
||||||
|
# NOTE new
|
||||||
|
shift_value = max(self.materials_dict.values())
|
||||||
|
self.materials_mines_dict = {**self.materials_dict, **{key: value + shift_value for key, value in self.mines_dict.items()}}
|
||||||
|
self.inverse_materials_mines_dict = {v: k for k, v in self.materials_mines_dict.items()}
|
||||||
|
#
|
||||||
|
|
||||||
|
self.__load_dialogue_act_labels()
|
||||||
|
self.__load_dialogue_move_labels()
|
||||||
|
self.__parse_dialogue()
|
||||||
|
self.__parse_questions()
|
||||||
|
self.__parse_start_end()
|
||||||
|
self.__parse_question_pairs()
|
||||||
|
self.__load_videos()
|
||||||
|
self.__assign_dialogue_act_labels()
|
||||||
|
self.__assign_dialogue_move_labels()
|
||||||
|
self.__load_replay_data()
|
||||||
|
self.__load_intermediate()
|
||||||
|
self.load_int0 = load_int0_feats
|
||||||
|
if load_int0_feats:
|
||||||
|
self.__load_int0_feats()
|
||||||
|
|
||||||
|
#############################################
|
||||||
|
################## GRAPHS ###################
|
||||||
|
|
||||||
|
#############################################
|
||||||
|
################ Global Plan ################
|
||||||
|
self.global_plan = nx.DiGraph()
|
||||||
|
mine_counter = 0
|
||||||
|
for n, v in zip(self.plan['materials'], self.plan['full']):
|
||||||
|
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
|
||||||
|
self.global_plan = self._add_node(self.global_plan, n, features=mat)
|
||||||
|
if v['make']:
|
||||||
|
#print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
|
||||||
|
mine = 0
|
||||||
|
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
|
||||||
|
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
|
||||||
|
m1 = onehot(m1, len(self.materials_mines_dict))
|
||||||
|
m2 = onehot(m2, len(self.materials_mines_dict))
|
||||||
|
self.global_plan = self._add_node(self.global_plan, self.plan['materials'][v['make'][0][0]], features=m1)
|
||||||
|
self.global_plan = self._add_node(self.global_plan, self.plan['materials'][v['make'][0][1]], features=m2)
|
||||||
|
# m1 -> mat
|
||||||
|
self.global_plan = self._add_edge(self.global_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
|
||||||
|
# m2 -> mat
|
||||||
|
self.global_plan = self._add_edge(self.global_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
|
||||||
|
else:
|
||||||
|
#print(n, v, self.plan['mines'][mine_counter])
|
||||||
|
mine = self.materials_mines_dict[self.plan['mines'][mine_counter]]
|
||||||
|
mine_counter += 1
|
||||||
|
mine = onehot(mine, len(self.materials_mines_dict))
|
||||||
|
self.global_plan = self._add_node(self.global_plan, self.plan['mines'][mine_counter], features=mine)
|
||||||
|
self.global_plan = self._add_edge(self.global_plan, self.plan['mines'][mine_counter], n, tool=t)
|
||||||
|
#self._plot_plan_graph(self.global_plan, filename=f"plots/global_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
|
||||||
|
self.global_plan = from_networkx(self.global_plan) # NOTE: I modified /torch_geometric/utils/convert.py, line 250
|
||||||
|
#############################################
|
||||||
|
|
||||||
|
#############################################
|
||||||
|
############### Player 1 Plan ###############
|
||||||
|
self.player1_plan = nx.DiGraph()
|
||||||
|
mine_counter = 0
|
||||||
|
for n,v in zip(self.plan['materials'], self.plan['player1']):
|
||||||
|
if v['make']:
|
||||||
|
mine = 0
|
||||||
|
if v['make'][0][0] < 0:
|
||||||
|
#print(n, v, "unknown", "unknown")
|
||||||
|
m1 = 0
|
||||||
|
m2 = 0
|
||||||
|
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
|
||||||
|
self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
|
||||||
|
else:
|
||||||
|
#print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
|
||||||
|
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
|
||||||
|
self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
|
||||||
|
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
|
||||||
|
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
|
||||||
|
m1 = onehot(m1, len(self.materials_mines_dict))
|
||||||
|
m2 = onehot(m2, len(self.materials_mines_dict))
|
||||||
|
self.player1_plan = self._add_node(self.player1_plan, self.plan['materials'][v['make'][0][0]], features=m1)
|
||||||
|
self.player1_plan = self._add_node(self.player1_plan, self.plan['materials'][v['make'][0][1]], features=m2)
|
||||||
|
# m1 -> mat
|
||||||
|
self.player1_plan = self._add_edge(self.player1_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
|
||||||
|
# m2 -> mat
|
||||||
|
self.player1_plan = self._add_edge(self.player1_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
|
||||||
|
else:
|
||||||
|
#print(n, v, self.plan['mines'][mine_counter])
|
||||||
|
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
|
||||||
|
self.player1_plan = self._add_node(self.player1_plan, n, features=mat)
|
||||||
|
mine = self.materials_mines_dict[self.plan['mines'][mine_counter]]
|
||||||
|
mine_counter += 1
|
||||||
|
mine = onehot(mine, len(self.materials_mines_dict))
|
||||||
|
self.player1_plan = self._add_node(self.player1_plan, self.plan['mines'][mine_counter], features=mine)
|
||||||
|
self.player1_plan = self._add_edge(self.player1_plan, self.plan['mines'][mine_counter], n, tool=t)
|
||||||
|
#self._plot_plan_graph(self.player1_plan, filename=f"plots/player1_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
|
||||||
|
self.player1_plan = from_networkx(self.player1_plan)
|
||||||
|
#############################################
|
||||||
|
|
||||||
|
#############################################
|
||||||
|
############### Player 2 Plan ###############
|
||||||
|
self.player2_plan = nx.DiGraph()
|
||||||
|
mine_counter = 0
|
||||||
|
for n,v in zip(self.plan['materials'], self.plan['player2']):
|
||||||
|
if v['make']:
|
||||||
|
mine = 0
|
||||||
|
if v['make'][0][0] < 0:
|
||||||
|
#print(n, v, "unknown", "unknown")
|
||||||
|
m1 = 0
|
||||||
|
m2 = 0
|
||||||
|
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
|
||||||
|
self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
|
||||||
|
else:
|
||||||
|
#print(n, v, self.plan['materials'][v['make'][0][0]], self.plan['materials'][v['make'][0][1]])
|
||||||
|
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
|
||||||
|
self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
|
||||||
|
m1 = self.materials_dict[self.plan['materials'][v['make'][0][0]]]
|
||||||
|
m2 = self.materials_dict[self.plan['materials'][v['make'][0][1]]]
|
||||||
|
m1 = onehot(m1, len(self.materials_mines_dict))
|
||||||
|
m2 = onehot(m2, len(self.materials_mines_dict))
|
||||||
|
self.player2_plan = self._add_node(self.player2_plan, self.plan['materials'][v['make'][0][0]], features=m1)
|
||||||
|
self.player2_plan = self._add_node(self.player2_plan, self.plan['materials'][v['make'][0][1]], features=m2)
|
||||||
|
# m1 -> mat
|
||||||
|
self.player2_plan = self._add_edge(self.player2_plan, self.plan['materials'][v['make'][0][0]], n, tool=t)
|
||||||
|
# m2 -> mat
|
||||||
|
self.player2_plan = self._add_edge(self.player2_plan, self.plan['materials'][v['make'][0][1]], n, tool=t)
|
||||||
|
else:
|
||||||
|
#print(n, v, self.plan['mines'][mine_counter])
|
||||||
|
mat = onehot(self.materials_mines_dict[n],len(self.materials_mines_dict))
|
||||||
|
t = onehot(self.tools_dict[self.plan['tools'][v['tools'][0]]], len(self.tools_dict))
|
||||||
|
self.player2_plan = self._add_node(self.player2_plan, n, features=mat)
|
||||||
|
mine = self.materials_mines_dict[self.plan['mines'][mine_counter]]
|
||||||
|
mine_counter += 1
|
||||||
|
mine = onehot(mine, len(self.materials_mines_dict))
|
||||||
|
self.player2_plan = self._add_node(self.player2_plan, self.plan['mines'][mine_counter], features=mine)
|
||||||
|
self.player2_plan = self._add_edge(self.player2_plan, self.plan['mines'][mine_counter], n, tool=t)
|
||||||
|
#self._plot_plan_graph(self.player2_plan, filename=f"plots/player2_{game_path.split('/')[-2]}_{game_path.split('/')[-1]}.png")
|
||||||
|
self.player2_plan = from_networkx(self.player2_plan)
|
||||||
|
|
||||||
|
# with open('graphs.pkl', 'wb') as f:
|
||||||
|
# pickle.dump([self.global_plan, self.player1_plan, self.player2_plan], f)
|
||||||
|
|
||||||
|
# construct a dict mapping materials to node indexes for each graph
|
||||||
|
p1_dict = {self.inverse_materials_mines_dict[torch.argmax(features).item()+1]: node_index for node_index, features in enumerate(self.player1_plan.features)}
|
||||||
|
p2_dict = {self.inverse_materials_mines_dict[torch.argmax(features).item()+1]: node_index for node_index, features in enumerate(self.player2_plan.features)}
|
||||||
|
# candidate edge = (u,v)
|
||||||
|
# u is from nodes with no out degree, v is from nodes with no in degree
|
||||||
|
p1_u_candidates = [p1_dict[i] for i in self.find_nodes_with_less_than_four_out_degree(self.player1_plan)]
|
||||||
|
p1_v_candidates = [p1_dict[i] for i in self.find_nodes_with_no_in_degree(self.player1_plan)]
|
||||||
|
p2_u_candidates = [p2_dict[i] for i in self.find_nodes_with_less_than_four_out_degree(self.player2_plan)]
|
||||||
|
p2_v_candidates = [p2_dict[i] for i in self.find_nodes_with_no_in_degree(self.player2_plan)]
|
||||||
|
# convert candidates to indexes
|
||||||
|
p1_edge_candidates = torch.tensor([(start, end) for start in p1_u_candidates for end in p1_v_candidates])
|
||||||
|
p2_edge_candidates = torch.tensor([(start, end) for start in p2_u_candidates for end in p2_v_candidates])
|
||||||
|
# find missing edges
|
||||||
|
gl_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.global_plan.features[edge[0]]).item()+1], self.inverse_materials_mines_dict[torch.argmax(self.global_plan.features[edge[1]]).item()+1]] for edge in self.global_plan.edge_index.t().tolist()]
|
||||||
|
p1_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.player1_plan.features[edge[0]]).item()+1], self.inverse_materials_mines_dict[torch.argmax(self.player1_plan.features[edge[1]]).item()+1]] for edge in self.player1_plan.edge_index.t().tolist()]
|
||||||
|
p2_edges = [[self.inverse_materials_mines_dict[torch.argmax(self.player2_plan.features[edge[0]]).item()+1], self.inverse_materials_mines_dict[torch.argmax(self.player2_plan.features[edge[1]]).item()+1]] for edge in self.player2_plan.edge_index.t().tolist()]
|
||||||
|
p1_missing_edges = [list(sublist) for sublist in set(map(tuple, gl_edges)) - set(map(tuple, p1_edges))]
|
||||||
|
p2_missing_edges = [list(sublist) for sublist in set(map(tuple, gl_edges)) - set(map(tuple, p2_edges))]
|
||||||
|
# convert missing edges as indexes
|
||||||
|
p1_missing_edges_idx = torch.tensor([(p1_dict[e[0]], p1_dict[e[1]]) for e in p1_missing_edges])
|
||||||
|
p2_missing_edges_idx = torch.tensor([(p2_dict[e[0]], p2_dict[e[1]]) for e in p2_missing_edges])
|
||||||
|
# check if all missing edges are present in the candidates
|
||||||
|
assert all(any(torch.equal(element, row) for row in p1_edge_candidates) for element in p1_missing_edges_idx)
|
||||||
|
assert all(any(torch.equal(element, row) for row in p2_edge_candidates) for element in p2_missing_edges_idx)
|
||||||
|
# concat candidates to plan graph
|
||||||
|
if p1_edge_candidates.numel() != 0:
|
||||||
|
self.player1_edge_label_index = torch.cat([self.player1_plan.edge_index, p1_edge_candidates.permute(1, 0)], dim=-1)
|
||||||
|
# create labels
|
||||||
|
self.player1_edge_label_own_missing_knowledge = torch.cat((torch.ones(self.player1_plan.edge_index.shape[1]), torch.zeros(p1_edge_candidates.shape[0])))
|
||||||
|
else:
|
||||||
|
# no missing knowledge
|
||||||
|
self.player1_edge_label_index = self.player1_plan.edge_index
|
||||||
|
# create labels
|
||||||
|
self.player1_edge_label_own_missing_knowledge = torch.ones(self.player1_plan.edge_index.shape[1])
|
||||||
|
if p2_edge_candidates.numel() != 0:
|
||||||
|
self.player2_edge_label_index = torch.cat([self.player2_plan.edge_index, p2_edge_candidates.permute(1, 0)], dim=-1)
|
||||||
|
# create labels
|
||||||
|
self.player2_edge_label_own_missing_knowledge = torch.cat((torch.ones(self.player2_plan.edge_index.shape[1]), torch.zeros(p2_edge_candidates.shape[0])))
|
||||||
|
else:
|
||||||
|
# no missing knowledge
|
||||||
|
self.player2_edge_label_index = self.player2_plan.edge_index
|
||||||
|
# create labels
|
||||||
|
self.player2_edge_label_own_missing_knowledge = torch.ones(self.player2_plan.edge_index.shape[1])
|
||||||
|
p1_edge_list = [tuple(x) for x in self.player1_edge_label_index.T.tolist()]
|
||||||
|
p1_missing_edges_idx_list = [tuple(x) for x in p1_missing_edges_idx.tolist()]
|
||||||
|
self.player1_edge_label_own_missing_knowledge[[p1_edge_list.index(x) for x in p1_missing_edges_idx_list]] = 1.
|
||||||
|
p2_edge_list = [tuple(x) for x in self.player2_edge_label_index.T.tolist()]
|
||||||
|
p2_missing_edges_idx_list = [tuple(x) for x in p2_missing_edges_idx.tolist()]
|
||||||
|
self.player2_edge_label_own_missing_knowledge[[p2_edge_list.index(x) for x in p2_missing_edges_idx_list]] = 1.
|
||||||
|
# compute other's missing knowledge == identify which one of my edges is unknown to the other player
|
||||||
|
p1_original_edges_list = [tuple(x) for x in self.player1_plan.edge_index.T.tolist()]
|
||||||
|
p2_original_edges_list = [tuple(x) for x in self.player2_plan.edge_index.T.tolist()]
|
||||||
|
p1_other_missing_edges_idx = [(p1_dict[e[0]], p1_dict[e[1]]) for e in p2_missing_edges] # note here is p2_missing_edges
|
||||||
|
p2_other_missing_edges_idx = [(p2_dict[e[0]], p2_dict[e[1]]) for e in p1_missing_edges] # note here is p1_missing_edges
|
||||||
|
self.player1_edge_label_other_missing_knowledge = torch.zeros(self.player1_plan.edge_index.shape[1])
|
||||||
|
self.player1_edge_label_other_missing_knowledge[[p1_original_edges_list.index(x) for x in p1_other_missing_edges_idx]] = 1.
|
||||||
|
self.player2_edge_label_other_missing_knowledge = torch.zeros(self.player2_plan.edge_index.shape[1])
|
||||||
|
self.player2_edge_label_other_missing_knowledge[[p2_original_edges_list.index(x) for x in p2_other_missing_edges_idx]] = 1.
|
||||||
|
|
||||||
|
self.__iter_ts = self.start_ts
|
||||||
|
|
||||||
|
self.action_labels = None
|
||||||
|
self.materials = sorted(self.plan['materials'])
|
||||||
|
|
||||||
|
def _add_node(self, g, material, features):
|
||||||
|
if material not in g.nodes:
|
||||||
|
#print(f'Add node {material}')
|
||||||
|
g.add_node(material, features=features)
|
||||||
|
return g
|
||||||
|
|
||||||
|
def _add_edge(self, g, u, v, tool):
|
||||||
|
if not g.has_edge(u, v):
|
||||||
|
#print(f'Add edge ({u}, {v})')
|
||||||
|
g.add_edge(u, v, tool=tool)
|
||||||
|
return g
|
||||||
|
|
||||||
|
def _plot_plan_graph(self, g, filename):
|
||||||
|
plt.figure(figsize=(20,20))
|
||||||
|
pos = nx.spring_layout(g, seed=42)
|
||||||
|
nx.draw(g, pos, with_labels=True, node_color='lightblue', edge_color='gray')
|
||||||
|
plt.savefig(filename)
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def find_nodes_with_less_than_four_out_degree(self, data):
|
||||||
|
edge_index = data.edge_index
|
||||||
|
num_nodes = data.num_nodes
|
||||||
|
degrees = degree(edge_index[0], num_nodes) # out degrees
|
||||||
|
# find all nodes that have out degree less than 2
|
||||||
|
nodes = torch.nonzero(degrees < 4).view(-1)
|
||||||
|
nodes = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item()+1] for i in nodes]
|
||||||
|
# remove planks (bc all planks have out degree less than 2)
|
||||||
|
nodes = [n for n in nodes if n.split('_')[-1] != 'PLANKS']
|
||||||
|
# now check for planks with out degree 0
|
||||||
|
check_zero_out_degree_planks = torch.nonzero(degrees < 1).view(-1)
|
||||||
|
check_zero_out_degree_planks = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item()+1] for i in check_zero_out_degree_planks]
|
||||||
|
check_zero_out_degree_planks = [n for n in nodes if n.split('_')[-1] == 'PLANKS']
|
||||||
|
nodes = nodes + check_zero_out_degree_planks
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
def find_nodes_with_no_in_degree(self, data):
|
||||||
|
edge_index = data.edge_index
|
||||||
|
num_nodes = data.num_nodes
|
||||||
|
degrees = degree(edge_index[1], num_nodes) # in degrees
|
||||||
|
nodes = torch.nonzero(degrees < 1).view(-1)
|
||||||
|
nodes = [self.inverse_materials_mines_dict[torch.argmax(data.features[i]).item()+1] for i in nodes]
|
||||||
|
nodes = [n for n in nodes if n.split('_')[-1] != 'PLANKS']
|
||||||
|
return nodes
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.end_ts - self.start_ts
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
if self.__iter_ts < self.end_ts:
|
||||||
|
if self.load_dialogue:
|
||||||
|
d = [x for x in self.dialogue_events if x[0] == self.__iter_ts]
|
||||||
|
l = [x for x in self.dialogue_act_labels if x[0] == self.__iter_ts]
|
||||||
|
d = d if d else None
|
||||||
|
l = l if l else None
|
||||||
|
else:
|
||||||
|
d = None
|
||||||
|
l = None
|
||||||
|
if self.use_dialogue_moves:
|
||||||
|
m = [x for x in self.dialogue_move_labels if x[0] == self.__iter_ts]
|
||||||
|
m = m if m else None
|
||||||
|
else:
|
||||||
|
m = None
|
||||||
|
if self.action_labels:
|
||||||
|
a = [x for x in self.action_labels if (x.TickIndex//30 + self.start_ts) >= self.__iter_ts]
|
||||||
|
if a:
|
||||||
|
try:
|
||||||
|
while not a[0].items:
|
||||||
|
a = a[1:]
|
||||||
|
al = self.materials.index(a[0].items[0]) if a else 0
|
||||||
|
except Exception:
|
||||||
|
print(a)
|
||||||
|
print(a[0])
|
||||||
|
print(a[0].items)
|
||||||
|
print(a[0].items[0])
|
||||||
|
exit()
|
||||||
|
at = a[0].TickIndex//30 + self.start_ts
|
||||||
|
an = int(a[0].Name[-1])
|
||||||
|
a = [(at,al,an)]
|
||||||
|
else:
|
||||||
|
a = [(self.__iter_ts, self.materials.index(self.plan['materials'][0]), 1)]
|
||||||
|
a = None
|
||||||
|
else:
|
||||||
|
if self.end_ts - self.__iter_ts < 10:
|
||||||
|
a = None
|
||||||
|
else:
|
||||||
|
a = None
|
||||||
|
if not a is None:
|
||||||
|
if not a[0][0] == self.__iter_ts:
|
||||||
|
a = None
|
||||||
|
q = [x for x in self.question_pairs if (x[0][1] == self.__iter_ts)]
|
||||||
|
q = q[0] if q else None
|
||||||
|
frame_idx = self.__iter_ts - self.start_ts
|
||||||
|
if self.load_third_person:
|
||||||
|
frames = self.third_pers_frames
|
||||||
|
elif self.load_player1:
|
||||||
|
frames = self.player1_pov_frames
|
||||||
|
elif self.load_player2:
|
||||||
|
frames = self.player2_pov_frames
|
||||||
|
else:
|
||||||
|
frames = np.array([0])
|
||||||
|
if len(frames) == 1:
|
||||||
|
f = np.zeros((self.img_h,self.img_w,3))
|
||||||
|
else:
|
||||||
|
if frame_idx < frames.shape[0]:
|
||||||
|
f = frames[frame_idx]
|
||||||
|
else:
|
||||||
|
f = np.zeros((self.img_h,self.img_w,3))
|
||||||
|
if self.do_upperbound:
|
||||||
|
if not q is None:
|
||||||
|
qnum = 0
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(q[0][2],2),
|
||||||
|
onehot(q[0][3],2),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
ToM6 = base_rep if self.ToM6 is not None else np.zeros(1024)
|
||||||
|
qnum = 1
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(q[0][2],2),
|
||||||
|
onehot(q[0][3],2),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(q[0][4][qnum][0]+1,2),
|
||||||
|
onehot(self.materials_dict[q[0][5][qnum][1]],len(self.materials_dict)),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][0][qnum])+1,3),
|
||||||
|
onehot(['YES','MAYBE','NO'].index(q[1][1][qnum])+1,3)
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
ToM7 = base_rep if self.ToM7 is not None else np.zeros(1024)
|
||||||
|
qnum = 2
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(q[0][2],2),
|
||||||
|
onehot(q[0][3],2),
|
||||||
|
onehot(q[0][4][qnum]+1,2),
|
||||||
|
onehot(q[0][4][qnum]+1,2),
|
||||||
|
onehot(self.materials_dict[q[1][0][qnum]] if q[1][0][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1),
|
||||||
|
onehot(self.materials_dict[q[1][1][qnum]] if q[1][1][qnum] in self.materials_dict else len(self.materials_dict)+1,len(self.materials_dict)+1)
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
ToM8 = base_rep if self.ToM8 is not None else np.zeros(1024)
|
||||||
|
else:
|
||||||
|
ToM6 = np.zeros(1024)
|
||||||
|
ToM7 = np.zeros(1024)
|
||||||
|
ToM8 = np.zeros(1024)
|
||||||
|
if not l is None:
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(l[0][1],2),
|
||||||
|
onehot(l[0][2],len(self.dialogue_act_labels_dict))
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
DAct = base_rep if self.DAct is not None else np.zeros(1024)
|
||||||
|
else:
|
||||||
|
DAct = np.zeros(1024)
|
||||||
|
if not m is None:
|
||||||
|
base_rep = np.concatenate([
|
||||||
|
onehot(m[0][1],2),
|
||||||
|
onehot(m[0][2][0],len(self.dialogue_move_labels_dict)),
|
||||||
|
onehot(m[0][2][1],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
onehot(m[0][2][2],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
onehot(m[0][2][3],len(self.tools_dict) + len(self.materials_dict) + len(self.mines_dict)+1),
|
||||||
|
])
|
||||||
|
base_rep = np.concatenate([base_rep, np.zeros(1024-base_rep.shape[0])])
|
||||||
|
DMove = base_rep if self.DMove is not None else np.zeros(1024)
|
||||||
|
else:
|
||||||
|
DMove = np.zeros(1024)
|
||||||
|
else:
|
||||||
|
ToM6 = self.ToM6[frame_idx] if self.ToM6 is not None else np.zeros(1024)
|
||||||
|
ToM7 = self.ToM7[frame_idx] if self.ToM7 is not None else np.zeros(1024)
|
||||||
|
ToM8 = self.ToM8[frame_idx] if self.ToM8 is not None else np.zeros(1024)
|
||||||
|
DAct = self.DAct[frame_idx] if self.DAct is not None else np.zeros(1024)
|
||||||
|
DMove = self.DAct[frame_idx] if self.DMove is not None else np.zeros(1024)
|
||||||
|
intermediate = np.concatenate([ToM6,ToM7,ToM8,DAct,DMove])
|
||||||
|
if self.load_int0:
|
||||||
|
intermediate = np.zeros(1024*5)
|
||||||
|
intermediate[:1024] = self.int0_exp2_feats[frame_idx]
|
||||||
|
retval = ((self.__iter_ts,self.pov),d,l,q,f,a,intermediate,m)
|
||||||
|
self.__iter_ts += 1
|
||||||
|
return retval
|
||||||
|
self.__iter_ts = self.start_ts
|
||||||
|
raise StopIteration()
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __load_videos(self):
|
||||||
|
d = self.end_ts - self.start_ts
|
||||||
|
if self.load_third_person:
|
||||||
|
try:
|
||||||
|
self.third_pers_file = glob(os.path.join(self.game_path,'third*gif'))[0]
|
||||||
|
np_file = self.third_pers_file[:-3]+'npz'
|
||||||
|
if os.path.isfile(np_file):
|
||||||
|
self.third_pers_frames = np.load(np_file)['data']
|
||||||
|
else:
|
||||||
|
frames = imageio.get_reader(self.third_pers_file, '.gif')
|
||||||
|
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
|
||||||
|
if 'main' in self.game_path:
|
||||||
|
self.third_pers_frames = np.array([reshaper(f[95:4*95,250:-249,2::-1]) for f in frames])
|
||||||
|
else:
|
||||||
|
self.third_pers_frames = np.array([reshaper(f[-3*95:,250:-249,2::-1]) for f in frames])
|
||||||
|
print(np_file,end=' ')
|
||||||
|
np.savez_compressed(open(np_file,'wb'), data=self.third_pers_frames)
|
||||||
|
print('saved')
|
||||||
|
except Exception as e:
|
||||||
|
self.third_pers_frames = np.array([0])
|
||||||
|
if self.third_pers_frames.shape[0]//d < 10:
|
||||||
|
self.third_pov_frame_rate = 6
|
||||||
|
else:
|
||||||
|
if self.third_pers_frames.shape[0]//d < 20:
|
||||||
|
self.third_pov_frame_rate = 12
|
||||||
|
else:
|
||||||
|
if self.third_pers_frames.shape[0]//d < 45:
|
||||||
|
self.third_pov_frame_rate = 30
|
||||||
|
else:
|
||||||
|
self.third_pov_frame_rate = 60
|
||||||
|
self.third_pers_frames = self.third_pers_frames[::self.third_pov_frame_rate]
|
||||||
|
else:
|
||||||
|
self.third_pers_frames = np.array([0])
|
||||||
|
if self.load_player1:
|
||||||
|
try:
|
||||||
|
search_str = 'play2*gif' if self.flip_video else 'play1*gif'
|
||||||
|
self.player1_pov_file = glob(os.path.join(self.game_path,search_str))[0]
|
||||||
|
np_file = self.player1_pov_file[:-3]+'npz'
|
||||||
|
if os.path.isfile(np_file):
|
||||||
|
self.player1_pov_frames = np.load(np_file)['data']
|
||||||
|
else:
|
||||||
|
frames = imageio.get_reader(self.player1_pov_file, '.gif')
|
||||||
|
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
|
||||||
|
self.player1_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
|
||||||
|
print(np_file,end=' ')
|
||||||
|
np.savez_compressed(open(np_file,'wb'), data=self.player1_pov_frames)
|
||||||
|
print('saved')
|
||||||
|
except Exception as e:
|
||||||
|
self.player1_pov_frames = np.array([0])
|
||||||
|
if self.player1_pov_frames.shape[0]//d < 10:
|
||||||
|
self.player1_pov_frame_rate = 6
|
||||||
|
else:
|
||||||
|
if self.player1_pov_frames.shape[0]//d < 20:
|
||||||
|
self.player1_pov_frame_rate = 12
|
||||||
|
else:
|
||||||
|
if self.player1_pov_frames.shape[0]//d < 45:
|
||||||
|
self.player1_pov_frame_rate = 30
|
||||||
|
else:
|
||||||
|
self.player1_pov_frame_rate = 60
|
||||||
|
self.player1_pov_frames = self.player1_pov_frames[::self.player1_pov_frame_rate]
|
||||||
|
else:
|
||||||
|
self.player1_pov_frames = np.array([0])
|
||||||
|
if self.load_player2:
|
||||||
|
try:
|
||||||
|
search_str = 'play1*gif' if self.flip_video else 'play2*gif'
|
||||||
|
self.player2_pov_file = glob(os.path.join(self.game_path,search_str))[0]
|
||||||
|
np_file = self.player2_pov_file[:-3]+'npz'
|
||||||
|
if os.path.isfile(np_file):
|
||||||
|
self.player2_pov_frames = np.load(np_file)['data']
|
||||||
|
else:
|
||||||
|
frames = imageio.get_reader(self.player2_pov_file, '.gif')
|
||||||
|
reshaper = lambda x: cv2.resize(x,(self.img_h,self.img_w))
|
||||||
|
self.player2_pov_frames = np.array([reshaper(f[:,:,2::-1]) for f in frames])
|
||||||
|
print(np_file,end=' ')
|
||||||
|
np.savez_compressed(open(np_file,'wb'), data=self.player2_pov_frames)
|
||||||
|
print('saved')
|
||||||
|
except Exception as e:
|
||||||
|
self.player2_pov_frames = np.array([0])
|
||||||
|
if self.player2_pov_frames.shape[0]//d < 10:
|
||||||
|
self.player2_pov_frame_rate = 6
|
||||||
|
else:
|
||||||
|
if self.player2_pov_frames.shape[0]//d < 20:
|
||||||
|
self.player2_pov_frame_rate = 12
|
||||||
|
else:
|
||||||
|
if self.player2_pov_frames.shape[0]//d < 45:
|
||||||
|
self.player2_pov_frame_rate = 30
|
||||||
|
else:
|
||||||
|
self.player2_pov_frame_rate = 60
|
||||||
|
self.player2_pov_frames = self.player2_pov_frames[::self.player2_pov_frame_rate]
|
||||||
|
else:
|
||||||
|
self.player2_pov_frames = np.array([0])
|
||||||
|
|
||||||
|
def __parse_question_pairs(self):
|
||||||
|
question_dict = {}
|
||||||
|
for q in self.questions:
|
||||||
|
k = q[2][0][1] + q[2][1][1]
|
||||||
|
if not k in question_dict:
|
||||||
|
question_dict[k] = []
|
||||||
|
question_dict[k].append(q)
|
||||||
|
self.question_pairs = []
|
||||||
|
for k,v in question_dict.items():
|
||||||
|
if len(v) == 2:
|
||||||
|
if v[0][1]+v[1][1] == 3:
|
||||||
|
self.question_pairs.append(v)
|
||||||
|
else:
|
||||||
|
while len(v) > 1:
|
||||||
|
pair = []
|
||||||
|
pair.append(v.pop(0))
|
||||||
|
pair.append(v.pop(0))
|
||||||
|
while not pair[0][1]+pair[1][1] == 3:
|
||||||
|
if not v:
|
||||||
|
break
|
||||||
|
# print(game_path,pair)
|
||||||
|
pair.append(v.pop(0))
|
||||||
|
pair.pop(0)
|
||||||
|
if not v:
|
||||||
|
break
|
||||||
|
self.question_pairs.append(pair)
|
||||||
|
self.question_pairs = sorted(self.question_pairs, key=lambda x: x[0][0])
|
||||||
|
if self.load_player2 or self.pov==4:
|
||||||
|
self.question_pairs = [sorted(q, key=lambda x: x[1],reverse=True) for q in self.question_pairs]
|
||||||
|
else:
|
||||||
|
self.question_pairs = [sorted(q, key=lambda x: x[1]) for q in self.question_pairs]
|
||||||
|
self.question_pairs = [((a[0], b[0], a[1], b[1], a[2], b[2]), (a[3], b[3])) for a,b in self.question_pairs]
|
||||||
|
|
||||||
|
def __parse_dialogue(self):
|
||||||
|
self.dialogue_events = []
|
||||||
|
save_path = os.path.join(self.game_path,f'dialogue_{self.game_path.split("/")[-1]}.pkl')
|
||||||
|
if os.path.isfile(save_path):
|
||||||
|
self.dialogue_events = pickle.load(open( save_path, "rb" ))
|
||||||
|
return
|
||||||
|
for x in open(self.dialogue_file):
|
||||||
|
if '[Async Chat Thread' in x:
|
||||||
|
ts = list(map(int,x.split(' [')[0].strip('[]').split(':')))
|
||||||
|
ts = 3600*ts[0] + 60*ts[1] + ts[2]
|
||||||
|
player, event = x.strip().split('/INFO]: []<sledmcc')[1].split('> ',1)
|
||||||
|
event = event.lower()
|
||||||
|
event = ''.join([x if x in string.ascii_lowercase else f' {x} ' for x in event]).strip()
|
||||||
|
event = event.replace(' ',' ').replace(' ',' ')
|
||||||
|
player = int(player)
|
||||||
|
if GameParser.tokenizer is None:
|
||||||
|
GameParser.tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)
|
||||||
|
if self.model is None:
|
||||||
|
GameParser.model = BertModel.from_pretrained('bert-large-uncased', output_hidden_states=True).to(DEVICE)
|
||||||
|
encoded_dict = GameParser.tokenizer.encode_plus(
|
||||||
|
event, # Sentence to encode.
|
||||||
|
add_special_tokens=True, # Add '[CLS]' and '[SEP]'
|
||||||
|
return_tensors='pt', # Return pytorch tensors.
|
||||||
|
)
|
||||||
|
token_ids = encoded_dict['input_ids'].to(DEVICE)
|
||||||
|
segment_ids = torch.ones(token_ids.size()).long().to(DEVICE)
|
||||||
|
GameParser.model.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
outputs = GameParser.model(input_ids=token_ids, token_type_ids=segment_ids)
|
||||||
|
outputs = outputs[1][0].cpu().data.numpy()
|
||||||
|
self.dialogue_events.append((ts,player,event,outputs))
|
||||||
|
pickle.dump(self.dialogue_events, open( save_path, "wb" ))
|
||||||
|
print(f'Saved to {save_path}',flush=True)
|
||||||
|
|
||||||
|
def __parse_questions(self):
|
||||||
|
self.questions = []
|
||||||
|
for x in open(self.questions_file):
|
||||||
|
if x[0] == '#':
|
||||||
|
ts, qs = x.strip().split(' Number of records inserted: 1 # player')
|
||||||
|
ts = list(map(int,ts.split(' ')[5].split(':')))
|
||||||
|
ts = 3600*ts[0] + 60*ts[1] + ts[2]
|
||||||
|
player = int(qs[0])
|
||||||
|
questions = qs[2:].split(';')
|
||||||
|
answers =[x[7:] for x in questions[3:]]
|
||||||
|
questions = [x[9:].split(' ') for x in questions[:3]]
|
||||||
|
questions[0] = (int(questions[0][0] == 'Have'), questions[0][-3])
|
||||||
|
questions[1] = (int(questions[1][2] == 'know'), questions[1][-1])
|
||||||
|
questions[2] = int(questions[2][1] == 'are')
|
||||||
|
|
||||||
|
self.questions.append((ts,player,questions,answers))
|
||||||
|
def __parse_start_end(self):
|
||||||
|
self.start_ts = [x.strip() for x in open(self.dialogue_file) if 'THEY ARE PLAYER' in x][1]
|
||||||
|
self.start_ts = list(map(int,self.start_ts.split('] [')[0][1:].split(':')))
|
||||||
|
self.start_ts = 3600*self.start_ts[0] + 60*self.start_ts[1] + self.start_ts[2]
|
||||||
|
try:
|
||||||
|
self.start_ts = max(self.start_ts, self.questions[0][0]-75)
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
self.end_ts = [x.strip() for x in open(self.dialogue_file) if 'Stopping' in x]
|
||||||
|
if self.end_ts:
|
||||||
|
self.end_ts = self.end_ts[0]
|
||||||
|
self.end_ts = list(map(int,self.end_ts.split('] [')[0][1:].split(':')))
|
||||||
|
self.end_ts = 3600*self.end_ts[0] + 60*self.end_ts[1] + self.end_ts[2]
|
||||||
|
else:
|
||||||
|
self.end_ts = self.dialogue_events[-1][0]
|
||||||
|
try:
|
||||||
|
self.end_ts = max(self.end_ts, self.questions[-1][0]) + 1
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __load_dialogue_act_labels(self):
|
||||||
|
file_name = 'config/dialogue_act_labels.json'
|
||||||
|
if not os.path.isfile(file_name):
|
||||||
|
files = sorted(glob('/home/*/MCC/*done.txt'))
|
||||||
|
dialogue_act_dict = {}
|
||||||
|
for file in files:
|
||||||
|
game_str = ''
|
||||||
|
for line in open(file):
|
||||||
|
line = line.strip()
|
||||||
|
if '_logs/' in line:
|
||||||
|
game_str = line
|
||||||
|
else:
|
||||||
|
if line:
|
||||||
|
line = line.split()
|
||||||
|
key = f'{game_str}#{line[0]}'
|
||||||
|
dialogue_act_dict[key] = line[-1]
|
||||||
|
json.dump(dialogue_act_dict,open(file_name,'w'), indent=4)
|
||||||
|
self.dialogue_act_dict = json.load(open(file_name))
|
||||||
|
self.dialogue_act_labels_dict = {l : i for i, l in enumerate(sorted(list(set(self.dialogue_act_dict.values()))))}
|
||||||
|
self.dialogue_act_bias = {l : sum([int(x==l) for x in self.dialogue_act_dict.values()]) for l in self.dialogue_act_labels_dict.keys()}
|
||||||
|
json.dump(self.dialogue_act_labels_dict,open('config/dialogue_act_label_names.json','w'), indent=4)
|
||||||
|
|
||||||
|
def __assign_dialogue_act_labels(self):
|
||||||
|
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
|
||||||
|
self.dialogue_act_labels = []
|
||||||
|
for emb in self.dialogue_events:
|
||||||
|
ts = emb[0]
|
||||||
|
h = ts//3600
|
||||||
|
m = (ts%3600)//60
|
||||||
|
s = ts%60
|
||||||
|
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
|
||||||
|
self.dialogue_act_labels.append((emb[0],emb[1],self.dialogue_act_labels_dict[self.dialogue_act_dict[key]]))
|
||||||
|
|
||||||
|
def __load_dialogue_move_labels(self):
|
||||||
|
file_name = "config/dialogue_move_labels.json"
|
||||||
|
dialogue_move_dict = {}
|
||||||
|
if not os.path.isfile(file_name):
|
||||||
|
file_text = ''
|
||||||
|
dialogue_moves = set()
|
||||||
|
for line in open("XXX"):
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
continue
|
||||||
|
if line[0] == '#':
|
||||||
|
continue
|
||||||
|
if line[0] == '[':
|
||||||
|
tag_text = glob(f'data/*/*/mcc_{file_text}.log')[0].split('/',1)[-1]
|
||||||
|
key = f'{tag_text}#{line.split()[0]}'
|
||||||
|
value = line.split()[-1].split('#')
|
||||||
|
if len(value) < 4:
|
||||||
|
value += ['IGNORE']*(4-len(value))
|
||||||
|
dialogue_moves.add(value[0])
|
||||||
|
value = '#'.join(value)
|
||||||
|
dialogue_move_dict[key] = value
|
||||||
|
else:
|
||||||
|
file_text = line
|
||||||
|
dialogue_moves = sorted(list(dialogue_moves))
|
||||||
|
|
||||||
|
json.dump(dialogue_move_dict,open(file_name,'w'), indent=4)
|
||||||
|
self.dialogue_move_dict = json.load(open(file_name))
|
||||||
|
self.dialogue_move_labels_dict = {l : i for i, l in enumerate(sorted(list(set([lbl.split('#')[0] for lbl in self.dialogue_move_dict.values()]))))}
|
||||||
|
self.dialogue_move_bias = {l : sum([int(x==l) for x in self.dialogue_move_dict.values()]) for l in self.dialogue_move_labels_dict.keys()}
|
||||||
|
json.dump(self.dialogue_move_labels_dict,open('config/dialogue_move_label_names.json','w'), indent=4)
|
||||||
|
|
||||||
|
def __assign_dialogue_move_labels(self):
|
||||||
|
log_file = glob('/'.join([self.game_path,'mcc*log']))[0].split('mindcraft/')[1]
|
||||||
|
self.dialogue_move_labels = []
|
||||||
|
for emb in self.dialogue_events:
|
||||||
|
ts = emb[0]
|
||||||
|
h = ts//3600
|
||||||
|
m = (ts%3600)//60
|
||||||
|
s = ts%60
|
||||||
|
key = f'{log_file}#[{h:02d}:{m:02d}:{s:02d}]:{emb[1]}>'
|
||||||
|
move = self.dialogue_move_dict[key].split('#')
|
||||||
|
move[0] = self.dialogue_move_labels_dict[move[0]]
|
||||||
|
for i,m in enumerate(move[1:]):
|
||||||
|
if m == 'IGNORE':
|
||||||
|
move[i+1] = 0
|
||||||
|
elif m in self.materials_dict:
|
||||||
|
move[i+1] = self.materials_dict[m]
|
||||||
|
elif m in self.mines_dict:
|
||||||
|
move[i+1] = self.mines_dict[m] + len(self.materials_dict)
|
||||||
|
elif m in self.tools_dict:
|
||||||
|
move[i+1] = self.tools_dict[m] + len(self.materials_dict) + len(self.mines_dict)
|
||||||
|
else:
|
||||||
|
print(move)
|
||||||
|
exit()
|
||||||
|
self.dialogue_move_labels.append((emb[0],emb[1],move))
|
||||||
|
|
||||||
|
def __load_replay_data(self):
|
||||||
|
self.actions = None
|
||||||
|
|
||||||
|
def __load_intermediate(self):
|
||||||
|
if self.intermediate > 15:
|
||||||
|
self.do_upperbound = True
|
||||||
|
else:
|
||||||
|
self.do_upperbound = False
|
||||||
|
if self.pov in [1,2]:
|
||||||
|
self.ToM6 = np.load(glob(f'{self.game_path}/intermediate_ToM6*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.ToM7 = np.load(glob(f'{self.game_path}/intermediate_ToM7*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.ToM8 = np.load(glob(f'{self.game_path}/intermediate_ToM8*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.DAct = np.load(glob(f'{self.game_path}/intermediate_DAct*player{self.pov}.npz')[0])['data'] if self.intermediate % 2 else None
|
||||||
|
self.intermediate = self.intermediate // 2
|
||||||
|
self.DMove = None
|
||||||
|
else:
|
||||||
|
self.ToM6 = None
|
||||||
|
self.ToM7 = None
|
||||||
|
self.ToM8 = None
|
||||||
|
self.DAct = None
|
||||||
|
self.DMove = None
|
||||||
|
|
||||||
|
def __load_int0_feats(self):
|
||||||
|
self.int0_exp2_feats = np.load(glob(f'{self.game_path}/int0_exp2*player{self.pov}.npz')[0])['data']
|
||||||
|
# self.int0_exp3_feats = np.load(np.load(glob(f'{self.game_path}/int0_exp3*player{self.pov}.npz')[0])['data'])
|
||||||
|
|
BIN
src/models/.DS_Store
vendored
Normal file
BIN
src/models/.DS_Store
vendored
Normal file
Binary file not shown.
0
src/models/__init__.py
Normal file
0
src/models/__init__.py
Normal file
207
src/models/losses.py
Normal file
207
src/models/losses.py
Normal file
|
@ -0,0 +1,207 @@
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import numpy as np
|
||||||
|
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class PlanLoss(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(PlanLoss, self).__init__()
|
||||||
|
def getWeights(self, output, target):
|
||||||
|
# return 1
|
||||||
|
f1 = (1+5*torch.stack([2-torch.sum(target.reshape(-1,21,21),dim=-1)]*21,dim=-1)).reshape(-1,21*21)
|
||||||
|
f2 = 100*target + 1
|
||||||
|
return (f1+f2)/60
|
||||||
|
exit(0)
|
||||||
|
# print(max(torch.sum(target.reshape(21,21),dim=-1)))
|
||||||
|
return (target*torch.sum(target,dim=-1) + 1)
|
||||||
|
def MSELoss(self, output, target):
|
||||||
|
retval = (output - target)**2
|
||||||
|
retval *= self.getWeights(output,target)
|
||||||
|
return torch.mean(retval)
|
||||||
|
def BCELoss(self, output, target, loss_mask=None):
|
||||||
|
mask_factor = torch.ones(target.shape).to(output.device)
|
||||||
|
if loss_mask is not None:
|
||||||
|
loss_mask = loss_mask.reshape(-1,21,21)
|
||||||
|
mask_factor = mask_factor.reshape(-1,21,21)
|
||||||
|
# print(mask_factor.shape,loss_mask.shape,output.shape,target.shape)
|
||||||
|
for idx, tgt in enumerate(loss_mask):
|
||||||
|
for jdx, tgt_node in enumerate(tgt):
|
||||||
|
if sum(tgt_node) == 0:
|
||||||
|
mask_factor[idx,jdx] *= 0
|
||||||
|
|
||||||
|
# print(loss_mask[0].data.cpu().numpy())
|
||||||
|
# print(mask_factor[0].data.cpu().numpy())
|
||||||
|
# print()
|
||||||
|
# print(loss_mask[45].data.cpu().numpy())
|
||||||
|
# print(mask_factor[45].data.cpu().numpy())
|
||||||
|
# print()
|
||||||
|
# print(loss_mask[-1].data.cpu().numpy())
|
||||||
|
# print(mask_factor[-1].data.cpu().numpy())
|
||||||
|
# print()
|
||||||
|
|
||||||
|
|
||||||
|
loss_mask = loss_mask.reshape(-1,21*21)
|
||||||
|
mask_factor = mask_factor.reshape(-1,21*21)
|
||||||
|
# print(loss_mask.shape, target.shape, mask_factor.shape)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
factor = (10 if target.shape[-1]==441 else 1)# * torch.sum(target,dim=-1)+1
|
||||||
|
retval = -1 * mask_factor * (factor * target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
|
||||||
|
factor = torch.stack([torch.sum(target,dim=-1)+1]*target.shape[-1],dim=-1)
|
||||||
|
return torch.mean(factor*retval)
|
||||||
|
return torch.mean(retval)
|
||||||
|
def forward(self, output, target, loss_mask=None):
|
||||||
|
return self.BCELoss(output,target,loss_mask) + 0.01*torch.sum(output - 1/21)
|
||||||
|
# return self.MSELoss(output,target)
|
||||||
|
|
||||||
|
class DialogueActLoss(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(DialogueActLoss, self).__init__()
|
||||||
|
self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5]).float()
|
||||||
|
self.bias = max(self.bias) - self.bias + 1
|
||||||
|
self.bias /= torch.sum(self.bias)
|
||||||
|
self.bias = 1-self.bias
|
||||||
|
# self.bias *= self.bias
|
||||||
|
def BCELoss(self, output, target):
|
||||||
|
target = torch.stack([torch.tensor(onehot(x + 1,19)).long() for x in target]).to(output.device)
|
||||||
|
retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
|
||||||
|
retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
|
||||||
|
# print(output)
|
||||||
|
# print(target)
|
||||||
|
# print(retval)
|
||||||
|
# print(torch.mean(retval))
|
||||||
|
# exit()
|
||||||
|
return torch.mean(retval)
|
||||||
|
def forward(self, output, target):
|
||||||
|
return self.BCELoss(output,target)
|
||||||
|
# return self.MSELoss(output,target)
|
||||||
|
|
||||||
|
class DialogueMoveLoss(nn.Module):
|
||||||
|
def __init__(self, device):
|
||||||
|
super(DialogueMoveLoss, self).__init__()
|
||||||
|
# self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5]).float()
|
||||||
|
# self.bias = max(self.bias) - self.bias + 1
|
||||||
|
# self.bias /= torch.sum(self.bias)
|
||||||
|
# self.bias = 1-self.bias
|
||||||
|
move_weights = torch.tensor(np.array([202, 34, 34, 48, 4, 2, 420, 10, 54, 1, 10, 11, 30, 28, 14, 2, 16, 6, 2, 86, 4, 12, 28, 2, 2, 16, 12, 14, 4, 1, 12, 258, 12, 26, 2])).float().to(device)
|
||||||
|
move_weights = 1+ max(move_weights) - move_weights
|
||||||
|
self.loss1 = CrossEntropyLoss(weight=move_weights)
|
||||||
|
zero_bias = 0.773
|
||||||
|
num_classes = 40
|
||||||
|
|
||||||
|
weight = torch.tensor(np.array([50 if not x else 1 for x in range(num_classes)])).float().to(device)
|
||||||
|
weight = 1+ max(weight) - weight
|
||||||
|
self.loss2 = CrossEntropyLoss(weight=weight)
|
||||||
|
# self.bias *= self.bias
|
||||||
|
def BCELoss(self, output, target,zero_bias):
|
||||||
|
# # print(output.shape,target.shape)
|
||||||
|
# bias = torch.tensor(np.array([1 if t else zero_bias for t in target])).to(output.device)
|
||||||
|
# target = torch.stack([torch.tensor(onehot(x,output.shape[-1])).long() for x in target]).to(output.device)
|
||||||
|
# # print(target.shape, bias.shape, bias)
|
||||||
|
|
||||||
|
# retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
|
||||||
|
# retval = torch.mean(retval,-1)
|
||||||
|
|
||||||
|
# # print(retval.shape)
|
||||||
|
# retval *= bias
|
||||||
|
# # retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
|
||||||
|
# # print(output)
|
||||||
|
# # print(target)
|
||||||
|
# # print(retval)
|
||||||
|
# # print(torch.mean(retval))
|
||||||
|
# # exit()
|
||||||
|
# # retval = self.loss(output,target)
|
||||||
|
# return torch.mean(retval) # retval #
|
||||||
|
# weight = [zero_bias if x else (1-zero_bias)/(output.shape[-1]-1) for x in range(output.shape[-1])]
|
||||||
|
retval = self.loss2(output,target) if zero_bias else self.loss1(output,target)
|
||||||
|
return retval #
|
||||||
|
def forward(self, output, target):
|
||||||
|
o1, o2, o3, o4 = output
|
||||||
|
t1, t2, t3, t4 = target
|
||||||
|
|
||||||
|
# print(t2,t2.shape, o2.shape)
|
||||||
|
|
||||||
|
# if sum(t2):
|
||||||
|
# o2, t2 = zip(*[(a,b) for a,b in zip(o2,t2) if b])
|
||||||
|
# o2 = torch.stack(o2)
|
||||||
|
# t2 = torch.stack(t2)
|
||||||
|
# if sum(t3):
|
||||||
|
# o3, t3 = zip(*[(a,b) for a,b in zip(o3,t3) if b])
|
||||||
|
# o3 = torch.stack(o3)
|
||||||
|
# t3 = torch.stack(t3)
|
||||||
|
# if sum(t4):
|
||||||
|
# o4, t4 = zip(*[(a,b) for a,b in zip(o4,t4) if b])
|
||||||
|
# o4 = torch.stack(o4)
|
||||||
|
# t4 = torch.stack(t4)
|
||||||
|
|
||||||
|
# print(t2,t2.shape, o2.shape)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
retval = sum([
|
||||||
|
1*self.BCELoss(output[0],target[0],0),
|
||||||
|
0*self.BCELoss(output[1],target[1],1),
|
||||||
|
0*self.BCELoss(output[2],target[2],1),
|
||||||
|
0*self.BCELoss(output[3],target[3],1)
|
||||||
|
])
|
||||||
|
return retval #sum([fact*self.BCELoss(o,t,zbias) for fact,zbias,o,t in zip([1,0,0,0],[0,1,1,1],output,target)])
|
||||||
|
# return self.MSELoss(output,target)
|
||||||
|
|
||||||
|
class DialoguePredLoss(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(DialoguePredLoss, self).__init__()
|
||||||
|
self.bias = torch.tensor([289,51,45,57,14,12,1,113,6,264,27,63,22,66,2,761,129,163,5,0]).float()
|
||||||
|
self.bias[-1] = 1460#2 * torch.sum(self.bias) // 3
|
||||||
|
self.bias = max(self.bias) - self.bias + 1
|
||||||
|
self.bias /= torch.sum(self.bias)
|
||||||
|
self.bias = 1-self.bias
|
||||||
|
# self.bias *= self.bias
|
||||||
|
def BCELoss(self, output, target):
|
||||||
|
target = torch.stack([torch.tensor(onehot(x + 1,20)).long() for x in target]).to(output.device)
|
||||||
|
retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
|
||||||
|
retval *= torch.stack([self.bias] * output.shape[0]).to(output.device)
|
||||||
|
# print(output)
|
||||||
|
# print(target)
|
||||||
|
# print(retval)
|
||||||
|
# print(torch.mean(retval))
|
||||||
|
# exit()
|
||||||
|
return torch.mean(retval)
|
||||||
|
def forward(self, output, target):
|
||||||
|
return self.BCELoss(output,target)
|
||||||
|
# return self.MSELoss(output,target)
|
||||||
|
|
||||||
|
class ActionLoss(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(ActionLoss, self).__init__()
|
||||||
|
self.bias1 = torch.tensor([134,1370,154,128,220,166,46,76,106,78,88,124,102,120,276,122,112,106,44,174,20]).float()
|
||||||
|
# self.bias[-1] = 1460#2 * torch.sum(self.bias) // 3
|
||||||
|
# self.bias1 = torch.ones(21).float()
|
||||||
|
self.bias1 = max(self.bias1) - self.bias1 + 1
|
||||||
|
self.bias1 /= torch.sum(self.bias1)
|
||||||
|
self.bias1 = 1-self.bias1
|
||||||
|
self.bias2 = torch.tensor([1168,1310]).float()
|
||||||
|
# self.bias2[-1] = 1460#2 * torch.sum(self.bias) // 3
|
||||||
|
# self.bias2 = torch.ones(21).float()
|
||||||
|
self.bias2 = max(self.bias2) - self.bias2 + 1
|
||||||
|
self.bias2 /= torch.sum(self.bias2)
|
||||||
|
self.bias2 = 1-self.bias2
|
||||||
|
# self.bias *= self.bias
|
||||||
|
def BCELoss(self, output, target):
|
||||||
|
# target = torch.stack([torch.tensor(onehot(x + 1,20)).long() for x in target]).to(output.device)
|
||||||
|
retval = -1 * (target * torch.log(1e-6+output) + (1-target) * torch.log(1e-6+1-output))
|
||||||
|
# print(self.bias1.shape,self.bias2.shape,output.shape[-1])
|
||||||
|
retval *= torch.stack([self.bias2 if output.shape[-1]==2 else self.bias1] * output.shape[0]).to(output.device)
|
||||||
|
# print(output)
|
||||||
|
# print(target)
|
||||||
|
# print(retval)
|
||||||
|
# print(torch.mean(retval))
|
||||||
|
# exit()
|
||||||
|
return torch.mean(retval)
|
||||||
|
def forward(self, output, target):
|
||||||
|
return self.BCELoss(output,target)
|
||||||
|
# return self.MSELoss(output,target)
|
205
src/models/model_with_dialogue_moves.py
Executable file
205
src/models/model_with_dialogue_moves.py
Executable file
|
@ -0,0 +1,205 @@
|
||||||
|
import sys, torch, random
|
||||||
|
from numpy.core.fromnumeric import reshape
|
||||||
|
import torch.nn as nn, numpy as np
|
||||||
|
from src.data.game_parser import DEVICE
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, seq_model_type=0,device=DEVICE):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
self.device = device
|
||||||
|
print("model set to device", self.device)
|
||||||
|
|
||||||
|
my_rnn = lambda i,o: nn.GRU(i,o)
|
||||||
|
#my_rnn = lambda i,o: nn.LSTM(i,o)
|
||||||
|
|
||||||
|
plan_emb_in = 81
|
||||||
|
plan_emb_out = 32
|
||||||
|
q_emb = 100
|
||||||
|
|
||||||
|
self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
|
||||||
|
# self.dialogue_listener = my_rnn(1126,768)
|
||||||
|
dlist_hidden = 1024
|
||||||
|
frame_emb = 512
|
||||||
|
self.move_emb = 157
|
||||||
|
drnn_in = 1024 + 2 + q_emb + frame_emb + self.move_emb
|
||||||
|
# drnn_in = 1024 + 2
|
||||||
|
|
||||||
|
# my_rnn = lambda i,o: nn.GRU(i,o)
|
||||||
|
my_rnn = lambda i,o: nn.LSTM(i,o)
|
||||||
|
|
||||||
|
if seq_model_type==0:
|
||||||
|
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==1:
|
||||||
|
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==2:
|
||||||
|
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(self.device)
|
||||||
|
sincos_fun = lambda x:torch.transpose(torch.stack([
|
||||||
|
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
|
||||||
|
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
|
||||||
|
]),0,1).reshape(-1,1,2)
|
||||||
|
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
|
||||||
|
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
|
||||||
|
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
|
||||||
|
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
|
||||||
|
sincos_fun(x.shape[0]).float().to(self.device),
|
||||||
|
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
|
||||||
|
], axis=-1))[0]
|
||||||
|
else:
|
||||||
|
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
conv_block = lambda i,o,k,p,s: nn.Sequential(
|
||||||
|
nn.Conv2d( i, o, k, padding=p, stride=s),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
conv_block( 3, 8, 3, 1, 1),
|
||||||
|
# conv_block( 3, 8, 5, 2, 2),
|
||||||
|
conv_block( 8, 32, 5, 2, 2),
|
||||||
|
conv_block( 32, frame_emb//4, 5, 2, 2),
|
||||||
|
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
qlayer = lambda i,o : nn.Sequential(
|
||||||
|
nn.Linear(i,512),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(512,o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
# nn.Softmax(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
q_in_size = 3*plan_emb_out+dlist_hidden+q_emb
|
||||||
|
|
||||||
|
self.q01 = qlayer(q_in_size,2)
|
||||||
|
self.q02 = qlayer(q_in_size,2)
|
||||||
|
self.q03 = qlayer(q_in_size,2)
|
||||||
|
|
||||||
|
self.q11 = qlayer(q_in_size,3)
|
||||||
|
self.q12 = qlayer(q_in_size,3)
|
||||||
|
self.q13 = qlayer(q_in_size,22)
|
||||||
|
|
||||||
|
self.q21 = qlayer(q_in_size,3)
|
||||||
|
self.q22 = qlayer(q_in_size,3)
|
||||||
|
self.q23 = qlayer(q_in_size,22)
|
||||||
|
|
||||||
|
def forward(self,game,global_plan=False, player_plan=False, intermediate=False):
|
||||||
|
retval = []
|
||||||
|
|
||||||
|
l = list(game)
|
||||||
|
_,d,_,q,f,_,_,m = zip(*list(game))
|
||||||
|
|
||||||
|
|
||||||
|
parse_move = lambda m: np.concatenate([
|
||||||
|
onehot(m[0][1], 2),
|
||||||
|
onehot(m[0][2][0]+1, len(game.dialogue_move_labels_dict)),
|
||||||
|
onehot(m[0][2][1]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
|
||||||
|
onehot(m[0][2][2]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
|
||||||
|
onehot(m[0][2][3]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1)
|
||||||
|
])
|
||||||
|
# print(2+len(game.dialogue_move_labels_dict)+3*(len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1))
|
||||||
|
# print(len(game.dialogue_move_labels_dict))
|
||||||
|
m = np.stack([np.zeros(self.move_emb) if move is None else parse_move(move) for move in m])
|
||||||
|
|
||||||
|
|
||||||
|
h = None
|
||||||
|
f = np.array(f, dtype=np.uint8)
|
||||||
|
# f = torch.tensor(f).permute(0,3,1,2).float().to(self.device)
|
||||||
|
# flt_lst = [(a,b) for a,b in zip(d,q) if (not a is None) or (not b is None)]
|
||||||
|
# if not flt_lst:
|
||||||
|
# return []
|
||||||
|
# d,q = zip(*flt_lst)
|
||||||
|
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
|
||||||
|
def parse_q(q):
|
||||||
|
if not q is None:
|
||||||
|
q ,l = q
|
||||||
|
q = np.concatenate([
|
||||||
|
onehot(q[2],2),
|
||||||
|
onehot(q[3],2),
|
||||||
|
onehot(q[4][0][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[4][1][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[4][2]+1,2),
|
||||||
|
onehot(q[5][0][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[5][1][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[5][2]+1,2)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
q = np.zeros(100)
|
||||||
|
l = None
|
||||||
|
return q, l
|
||||||
|
try:
|
||||||
|
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
|
||||||
|
sel2 = 1 - sel1
|
||||||
|
except Exception as e:
|
||||||
|
sel1 = 0
|
||||||
|
sel2 = 0
|
||||||
|
q = [parse_q(x) for x in q]
|
||||||
|
q, l = zip(*q)
|
||||||
|
q = np.stack(q)
|
||||||
|
|
||||||
|
if not global_plan and not player_plan:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
plan_emb = 0*plan_emb
|
||||||
|
elif global_plan:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
|
||||||
|
u = torch.cat((
|
||||||
|
torch.tensor(d).float().to(self.device),
|
||||||
|
torch.tensor(q).float().to(self.device),
|
||||||
|
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
|
||||||
|
torch.tensor(m).float().to(self.device)
|
||||||
|
),axis=-1)
|
||||||
|
u = u.float().to(self.device)
|
||||||
|
|
||||||
|
y = self.dialogue_listener(u)
|
||||||
|
y = y.reshape(-1,y.shape[-1])
|
||||||
|
|
||||||
|
if intermediate:
|
||||||
|
return y
|
||||||
|
|
||||||
|
if all([x is None for x in l]):
|
||||||
|
return []
|
||||||
|
|
||||||
|
fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
|
||||||
|
fun = lambda x: [f(x) for f in fun_lst]
|
||||||
|
|
||||||
|
retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l)if not _l is None]
|
||||||
|
return retval
|
226
src/models/model_with_dialogue_moves_graphs.py
Normal file
226
src/models/model_with_dialogue_moves_graphs.py
Normal file
|
@ -0,0 +1,226 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn, numpy as np
|
||||||
|
from src.data.game_parser import DEVICE
|
||||||
|
from torch_geometric.nn import GATv2Conv, MeanAggregation
|
||||||
|
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
|
||||||
|
class PlanGraphEmbedder(nn.Module):
|
||||||
|
def __init__(self, h_dim, dropout=0.0, heads=4):
|
||||||
|
super().__init__()
|
||||||
|
self.proj_x = nn.Sequential(
|
||||||
|
nn.Linear(27, h_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
self.proj_edge_attr = nn.Sequential(
|
||||||
|
nn.Linear(12, h_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
|
||||||
|
self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.pool = MeanAggregation()
|
||||||
|
|
||||||
|
def forward(self, data):
|
||||||
|
x, edge_index, edge_attr = data.features.to(DEVICE), data.edge_index.to(DEVICE), data.tool.to(DEVICE)
|
||||||
|
x = self.proj_x(x)
|
||||||
|
edge_attr = self.proj_edge_attr(edge_attr)
|
||||||
|
x = self.conv1(x, edge_index, edge_attr)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.conv2(x, edge_index, edge_attr)
|
||||||
|
x = self.pool(x)
|
||||||
|
return x.squeeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, seq_model_type=0,device=DEVICE):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
self.device = device
|
||||||
|
print("model set to device", self.device)
|
||||||
|
|
||||||
|
plan_emb_out = 32*3
|
||||||
|
q_emb = 100
|
||||||
|
|
||||||
|
self.plan_embedder0 = PlanGraphEmbedder(plan_emb_out)
|
||||||
|
self.plan_embedder1 = PlanGraphEmbedder(plan_emb_out)
|
||||||
|
self.plan_embedder2 = PlanGraphEmbedder(plan_emb_out)
|
||||||
|
|
||||||
|
dlist_hidden = 1024
|
||||||
|
frame_emb = 512
|
||||||
|
self.move_emb = 157
|
||||||
|
drnn_in = 1024 + 2 + q_emb + frame_emb + self.move_emb
|
||||||
|
|
||||||
|
if seq_model_type==0:
|
||||||
|
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==1:
|
||||||
|
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==2:
|
||||||
|
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(self.device)
|
||||||
|
sincos_fun = lambda x:torch.transpose(torch.stack([
|
||||||
|
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
|
||||||
|
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
|
||||||
|
]),0,1).reshape(-1,1,2)
|
||||||
|
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
|
||||||
|
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
|
||||||
|
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
|
||||||
|
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
|
||||||
|
sincos_fun(x.shape[0]).float().to(self.device),
|
||||||
|
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
|
||||||
|
], axis=-1))[0]
|
||||||
|
else:
|
||||||
|
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
conv_block = lambda i,o,k,p,s: nn.Sequential(
|
||||||
|
nn.Conv2d( i, o, k, padding=p, stride=s),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
conv_block( 3, 8, 3, 1, 1),
|
||||||
|
conv_block( 8, 32, 5, 2, 2),
|
||||||
|
conv_block( 32, frame_emb//4, 5, 2, 2),
|
||||||
|
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
qlayer = lambda i,o : nn.Sequential(
|
||||||
|
nn.Linear(i,512),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(512,o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
q_in_size = plan_emb_out+dlist_hidden+q_emb
|
||||||
|
|
||||||
|
self.q01 = qlayer(q_in_size,2)
|
||||||
|
self.q02 = qlayer(q_in_size,2)
|
||||||
|
self.q03 = qlayer(q_in_size,2)
|
||||||
|
|
||||||
|
self.q11 = qlayer(q_in_size,3)
|
||||||
|
self.q12 = qlayer(q_in_size,3)
|
||||||
|
self.q13 = qlayer(q_in_size,22)
|
||||||
|
|
||||||
|
self.q21 = qlayer(q_in_size,3)
|
||||||
|
self.q22 = qlayer(q_in_size,3)
|
||||||
|
self.q23 = qlayer(q_in_size,22)
|
||||||
|
|
||||||
|
def forward(self,game,global_plan=False, player_plan=False, intermediate=False):
|
||||||
|
retval = []
|
||||||
|
|
||||||
|
l = list(game)
|
||||||
|
_,d,_,q,f,_,_,m = zip(*list(game))
|
||||||
|
|
||||||
|
parse_move = lambda m: np.concatenate([
|
||||||
|
onehot(m[0][1], 2),
|
||||||
|
onehot(m[0][2][0]+1, len(game.dialogue_move_labels_dict)),
|
||||||
|
onehot(m[0][2][1]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
|
||||||
|
onehot(m[0][2][2]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1),
|
||||||
|
onehot(m[0][2][3]+1, len(game.materials_dict)+len(game.mines_dict)+len(game.tools_dict)+1)
|
||||||
|
])
|
||||||
|
m = np.stack([np.zeros(self.move_emb) if move is None else parse_move(move) for move in m])
|
||||||
|
|
||||||
|
f = np.array(f, dtype=np.uint8)
|
||||||
|
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
|
||||||
|
def parse_q(q):
|
||||||
|
if not q is None:
|
||||||
|
q ,l = q
|
||||||
|
q = np.concatenate([
|
||||||
|
onehot(q[2],2),
|
||||||
|
onehot(q[3],2),
|
||||||
|
onehot(q[4][0][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[4][1][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[4][2]+1,2),
|
||||||
|
onehot(q[5][0][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[5][1][0]+1,2),
|
||||||
|
onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
|
||||||
|
onehot(q[5][2]+1,2)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
q = np.zeros(100)
|
||||||
|
l = None
|
||||||
|
return q, l
|
||||||
|
try:
|
||||||
|
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
|
||||||
|
sel2 = 1 - sel1
|
||||||
|
except Exception as e:
|
||||||
|
sel1 = 0
|
||||||
|
sel2 = 0
|
||||||
|
q = [parse_q(x) for x in q]
|
||||||
|
q, l = zip(*q)
|
||||||
|
q = np.stack(q)
|
||||||
|
|
||||||
|
if not global_plan and not player_plan:
|
||||||
|
# plan_emb = torch.cat([
|
||||||
|
# self.plan_embedder0(game.global_plan),
|
||||||
|
# self.plan_embedder1(game.player1_plan),
|
||||||
|
# self.plan_embedder2(game.player2_plan)
|
||||||
|
# ])
|
||||||
|
plan_emb = self.plan_embedder0(game.global_plan)
|
||||||
|
plan_emb = 0*plan_emb
|
||||||
|
elif global_plan:
|
||||||
|
# plan_emb = torch.cat([
|
||||||
|
# self.plan_embedder0(game.global_plan),
|
||||||
|
# self.plan_embedder1(game.player1_plan),
|
||||||
|
# self.plan_embedder2(game.player2_plan)
|
||||||
|
# ])
|
||||||
|
plan_emb = self.plan_embedder0(game.global_plan)
|
||||||
|
else:
|
||||||
|
# plan_emb = torch.cat([
|
||||||
|
# 0*self.plan_embedder0(game.global_plan),
|
||||||
|
# sel1*self.plan_embedder1(game.player1_plan),
|
||||||
|
# sel2*self.plan_embedder2(game.player2_plan)
|
||||||
|
# ])
|
||||||
|
if sel1:
|
||||||
|
plan_emb = self.plan_embedder1(game.player1_plan)
|
||||||
|
elif sel2:
|
||||||
|
plan_emb = self.plan_embedder2(game.player2_plan)
|
||||||
|
else:
|
||||||
|
plan_emb = self.plan_embedder0(game.global_plan)
|
||||||
|
plan_emb = 0*plan_emb
|
||||||
|
|
||||||
|
u = torch.cat((
|
||||||
|
torch.tensor(d).float().to(self.device),
|
||||||
|
torch.tensor(q).float().to(self.device),
|
||||||
|
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
|
||||||
|
torch.tensor(m).float().to(self.device)
|
||||||
|
),axis=-1)
|
||||||
|
u = u.float().to(self.device)
|
||||||
|
|
||||||
|
y = self.dialogue_listener(u)
|
||||||
|
y = y.reshape(-1,y.shape[-1])
|
||||||
|
|
||||||
|
if intermediate:
|
||||||
|
return y
|
||||||
|
|
||||||
|
if all([x is None for x in l]):
|
||||||
|
return []
|
||||||
|
|
||||||
|
fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
|
||||||
|
fun = lambda x: [f(x) for f in fun_lst]
|
||||||
|
|
||||||
|
retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l)if not _l is None]
|
||||||
|
return retval
|
225
src/models/plan_model.py
Executable file
225
src/models/plan_model.py
Executable file
|
@ -0,0 +1,225 @@
|
||||||
|
import sys, torch, random
|
||||||
|
from numpy.core.fromnumeric import reshape
|
||||||
|
import torch.nn as nn, numpy as np
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from src.data.game_parser import DEVICE
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, seq_model_type=0,device=DEVICE):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
my_rnn = lambda i,o: nn.GRU(i,o)
|
||||||
|
#my_rnn = lambda i,o: nn.LSTM(i,o)
|
||||||
|
|
||||||
|
plan_emb_in = 81
|
||||||
|
plan_emb_out = 32
|
||||||
|
q_emb = 100
|
||||||
|
|
||||||
|
self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
|
||||||
|
# self.dialogue_listener = my_rnn(1126,768)
|
||||||
|
dlist_hidden = 1024
|
||||||
|
frame_emb = 512
|
||||||
|
drnn_in = 5*1024 + 2 + frame_emb + 1024
|
||||||
|
# drnn_in = 1024 + 2
|
||||||
|
|
||||||
|
# my_rnn = lambda i,o: nn.GRU(i,o)
|
||||||
|
my_rnn = lambda i,o: nn.LSTM(i,o)
|
||||||
|
|
||||||
|
if seq_model_type==0:
|
||||||
|
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==1:
|
||||||
|
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==2:
|
||||||
|
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(device)
|
||||||
|
sincos_fun = lambda x:torch.transpose(torch.stack([
|
||||||
|
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
|
||||||
|
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
|
||||||
|
]),0,1).reshape(-1,1,2)
|
||||||
|
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
|
||||||
|
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
|
||||||
|
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
|
||||||
|
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
|
||||||
|
sincos_fun(x.shape[0]).float().to(self.device),
|
||||||
|
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
|
||||||
|
], axis=-1))[0]
|
||||||
|
else:
|
||||||
|
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
conv_block = lambda i,o,k,p,s: nn.Sequential(
|
||||||
|
nn.Conv2d( i, o, k, padding=p, stride=s),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
conv_block( 3, 8, 3, 1, 1),
|
||||||
|
# conv_block( 3, 8, 5, 2, 2),
|
||||||
|
conv_block( 8, 32, 5, 2, 2),
|
||||||
|
conv_block( 32, frame_emb//4, 5, 2, 2),
|
||||||
|
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
plan_layer = lambda i,o : nn.Sequential(
|
||||||
|
# nn.Linear(i,(i+o)//2),
|
||||||
|
nn.Linear(i,(i+2*o)//3),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
# nn.Linear((i+o)//2,o),
|
||||||
|
nn.Linear((i+2*o)//3,o),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
# nn.Sigmoid()
|
||||||
|
)
|
||||||
|
|
||||||
|
plan_mat_size = 21*21
|
||||||
|
q_in_size = 3*plan_emb_out+dlist_hidden
|
||||||
|
q_in_size = 3*plan_emb_out+dlist_hidden+plan_mat_size
|
||||||
|
q_in_size = dlist_hidden+plan_mat_size
|
||||||
|
|
||||||
|
# self.plan_out = plan_layer(q_in_size,plan_mat_size)
|
||||||
|
self.plan_out = plan_layer(q_in_size,plan_mat_size)
|
||||||
|
# self.q01 = qlayer(q_in_size,2)
|
||||||
|
# self.q02 = qlayer(q_in_size,2)
|
||||||
|
# self.q03 = qlayer(q_in_size,2)
|
||||||
|
|
||||||
|
# self.q11 = qlayer(q_in_size,3)
|
||||||
|
# self.q12 = qlayer(q_in_size,3)
|
||||||
|
# self.q13 = qlayer(q_in_size,22)
|
||||||
|
|
||||||
|
# self.q21 = qlayer(q_in_size,3)
|
||||||
|
# self.q22 = qlayer(q_in_size,3)
|
||||||
|
# self.q23 = qlayer(q_in_size,22)
|
||||||
|
|
||||||
|
def forward(self,game,global_plan=False, player_plan=False,evaluation=False, incremental=False):
|
||||||
|
retval = []
|
||||||
|
|
||||||
|
l = list(game)
|
||||||
|
_,d,l,q,f,_,intermediate,_ = zip(*list(game))
|
||||||
|
# print(np.array(intermediate).shape)
|
||||||
|
# exit()
|
||||||
|
|
||||||
|
h = None
|
||||||
|
intermediate = np.array(intermediate)
|
||||||
|
f = np.array(f, dtype=np.uint8)
|
||||||
|
# f = torch.tensor(f).permute(0,3,1,2).float().to(self.device)
|
||||||
|
# flt_lst = [(a,b) for a,b in zip(d,q) if (not a is None) or (not b is None)]
|
||||||
|
# if not flt_lst:
|
||||||
|
# return []
|
||||||
|
# d,q = zip(*flt_lst)
|
||||||
|
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
|
||||||
|
# def parse_q(q):
|
||||||
|
# if not q is None:
|
||||||
|
# q ,l = q
|
||||||
|
# q = np.concatenate([
|
||||||
|
# onehot(q[2],2),
|
||||||
|
# onehot(q[3],2),
|
||||||
|
# onehot(q[4][0][0]+1,2),
|
||||||
|
# onehot(game.materials_dict[q[4][0][1]],len(game.materials_dict)),
|
||||||
|
# onehot(q[4][1][0]+1,2),
|
||||||
|
# onehot(game.materials_dict[q[4][1][1]],len(game.materials_dict)),
|
||||||
|
# onehot(q[4][2]+1,2),
|
||||||
|
# onehot(q[5][0][0]+1,2),
|
||||||
|
# onehot(game.materials_dict[q[5][0][1]],len(game.materials_dict)),
|
||||||
|
# onehot(q[5][1][0]+1,2),
|
||||||
|
# onehot(game.materials_dict[q[5][1][1]],len(game.materials_dict)),
|
||||||
|
# onehot(q[5][2]+1,2)
|
||||||
|
# ])
|
||||||
|
# else:
|
||||||
|
# q = np.zeros(100)
|
||||||
|
# l = None
|
||||||
|
# return q, l
|
||||||
|
try:
|
||||||
|
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
|
||||||
|
sel2 = 1 - sel1
|
||||||
|
except Exception as e:
|
||||||
|
sel1 = 0
|
||||||
|
sel2 = 0
|
||||||
|
# q = [parse_q(x) for x in q]
|
||||||
|
# q, l = zip(*q)
|
||||||
|
|
||||||
|
|
||||||
|
if not global_plan and not player_plan:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
plan_emb = 0*plan_emb
|
||||||
|
elif global_plan:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
|
||||||
|
# if sel1 == 0 and sel2 == 0:
|
||||||
|
# print(torch.unique(plan_emb))
|
||||||
|
|
||||||
|
u = torch.cat((
|
||||||
|
torch.tensor(d).float().to(self.device),
|
||||||
|
# torch.tensor(q).float().to(self.device),
|
||||||
|
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
|
||||||
|
torch.tensor(intermediate).float().to(self.device)
|
||||||
|
),axis=-1)
|
||||||
|
u = u.float().to(self.device)
|
||||||
|
# print(d.shape)
|
||||||
|
# print(self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512).shape)
|
||||||
|
# print(intermediate.shape)
|
||||||
|
# print(u.shape)
|
||||||
|
|
||||||
|
y = self.dialogue_listener(u)
|
||||||
|
y = y.reshape(-1,y.shape[-1])
|
||||||
|
# print(y[-1].shape,plan_emb.shape,torch.tensor(game.plan_repr).float().to(self.device).shape)
|
||||||
|
# return self.plan_out(torch.cat((y[-1],plan_emb))), y
|
||||||
|
# return self.plan_out(torch.cat((y[-1],plan_emb,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))), y
|
||||||
|
if incremental:
|
||||||
|
prediction = torch.stack([
|
||||||
|
self.plan_out(torch.cat((y[0],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))] + [
|
||||||
|
self.plan_out(torch.cat((f,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))) for f in y[len(y)%10-1::10]
|
||||||
|
])
|
||||||
|
prediction = F.softmax(prediction.reshape(-1,21,21),-1).reshape(-1,21*21)
|
||||||
|
else:
|
||||||
|
prediction = self.plan_out(torch.cat((y[-1],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))
|
||||||
|
prediction = F.softmax(prediction.reshape(21,21),-1).reshape(21*21)
|
||||||
|
# prediction = F.softmax(prediction,-1)
|
||||||
|
# prediction = F.softmax(prediction,-1)
|
||||||
|
# exit()
|
||||||
|
return prediction, y
|
||||||
|
|
||||||
|
# exit()
|
||||||
|
# if all([x is None for x in l]):
|
||||||
|
# return []
|
||||||
|
|
||||||
|
# fun_lst = [self.q01,self.q02,self.q03,self.q11,self.q12,self.q13,self.q21,self.q22,self.q23]
|
||||||
|
# fun = lambda x: [f(x) for f in fun_lst]
|
||||||
|
|
||||||
|
|
||||||
|
# retval = [(_l,fun(torch.cat((plan_emb,torch.tensor(_q).float().to(self.device),_y)))) for _y, _q, _l in zip(y,q,l) if not _l is None]
|
||||||
|
# return retval
|
230
src/models/plan_model_graphs.py
Normal file
230
src/models/plan_model_graphs.py
Normal file
|
@ -0,0 +1,230 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn, numpy as np
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch_geometric.nn import MeanAggregation, GATv2Conv
|
||||||
|
|
||||||
|
|
||||||
|
class PlanGraphEmbedder(nn.Module):
|
||||||
|
def __init__(self, device, h_dim, dropout=0.0, heads=4):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
|
self.proj_x = nn.Sequential(
|
||||||
|
nn.Linear(27, h_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
self.proj_edge_attr = nn.Sequential(
|
||||||
|
nn.Linear(12, h_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
|
||||||
|
self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.dec = nn.Linear(h_dim*3, 1)
|
||||||
|
|
||||||
|
def encode(self, data):
|
||||||
|
x, edge_index, edge_attr = data.features.to(self.device), data.edge_index.to(self.device), data.tool.to(self.device)
|
||||||
|
x = self.proj_x(x)
|
||||||
|
edge_attr = self.proj_edge_attr(edge_attr)
|
||||||
|
x = self.conv1(x, edge_index, edge_attr)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.conv2(x, edge_index, edge_attr)
|
||||||
|
return x, edge_attr
|
||||||
|
|
||||||
|
def decode(self, z, context, edge_label_index):
|
||||||
|
u = z[edge_label_index[0]]
|
||||||
|
v = z[edge_label_index[1]]
|
||||||
|
return self.dec(torch.cat((u, v, context), -1))
|
||||||
|
|
||||||
|
# def decode(self, z, edge_index, edge_attr, edge_label_index):
|
||||||
|
# z = self.conv3(z, edge_index.to(self.device), edge_attr)
|
||||||
|
# return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, seq_model_type, device):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
plan_emb_out = 128
|
||||||
|
self.plan_embedder0 = PlanGraphEmbedder(device, plan_emb_out)
|
||||||
|
self.plan_embedder1 = PlanGraphEmbedder(device, plan_emb_out)
|
||||||
|
self.plan_embedder2 = PlanGraphEmbedder(device, plan_emb_out)
|
||||||
|
self.plan_pool = MeanAggregation()
|
||||||
|
dlist_hidden = 1024
|
||||||
|
frame_emb = 512
|
||||||
|
drnn_in = 5*1024 + 2 + frame_emb + 1024
|
||||||
|
self.dialogue_listener_pre_ln = nn.LayerNorm(drnn_in)
|
||||||
|
if seq_model_type==0:
|
||||||
|
self.dialogue_listener_rnn = nn.GRU(drnn_in, dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
|
||||||
|
elif seq_model_type==1:
|
||||||
|
self.dialogue_listener_rnn = nn.LSTM(drnn_in, dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
|
||||||
|
elif seq_model_type==2:
|
||||||
|
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool().to(device)
|
||||||
|
sincos_fun = lambda x:torch.transpose(torch.stack([
|
||||||
|
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
|
||||||
|
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
|
||||||
|
]),0,1).reshape(-1,1,2)
|
||||||
|
self.dialogue_listener_lin1 = nn.Linear(drnn_in, dlist_hidden-2)
|
||||||
|
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
|
||||||
|
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x, x, x, attn_mask=mask_fun(x))
|
||||||
|
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
|
||||||
|
sincos_fun(x.shape[0]).float().to(self.device),
|
||||||
|
self.dialogue_listener_lin1(x).reshape(-1, 1, dlist_hidden-2)
|
||||||
|
], axis=-1))[0]
|
||||||
|
else:
|
||||||
|
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
|
||||||
|
exit()
|
||||||
|
conv_block = lambda i, o, k, p, s: nn.Sequential(
|
||||||
|
nn.Conv2d(i, o, k, padding=p, stride=s),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
conv_block(3, 8, 3, 1, 1),
|
||||||
|
conv_block(8, 32, 5, 2, 2),
|
||||||
|
conv_block(32, frame_emb//4, 5, 2, 2),
|
||||||
|
nn.Conv2d(frame_emb//4, frame_emb, 3),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.proj_y = nn.Sequential(
|
||||||
|
nn.Linear(1024, 512),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(512, plan_emb_out)
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, game, experiment, global_plan=False, player_plan=False, incremental=False, return_feats=False):
|
||||||
|
|
||||||
|
_, d, l, q, f, _, intermediate, _ = zip(*list(game))
|
||||||
|
intermediate = np.array(intermediate)
|
||||||
|
f = np.array(f, dtype=np.uint8)
|
||||||
|
d = np.stack([np.concatenate(([int(x[0][1]==2), int(x[0][1]==1)], x[0][-1])) if not x is None else np.zeros(1026) for x in d])
|
||||||
|
try:
|
||||||
|
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
|
||||||
|
sel2 = 1 - sel1
|
||||||
|
except Exception as e:
|
||||||
|
sel1 = 0
|
||||||
|
sel2 = 0
|
||||||
|
|
||||||
|
if player_plan:
|
||||||
|
if sel1:
|
||||||
|
z, _ = self.plan_embedder1.encode(game.player1_plan)
|
||||||
|
elif sel2:
|
||||||
|
z, _ = self.plan_embedder2.encode(game.player2_plan)
|
||||||
|
else:
|
||||||
|
z, _ = self.plan_embedder0.encode(game.global_plan)
|
||||||
|
else:
|
||||||
|
raise ValueError('There should never be a global plan!')
|
||||||
|
|
||||||
|
u = torch.cat((
|
||||||
|
torch.tensor(d).float().to(self.device),
|
||||||
|
self.conv(torch.tensor(f).permute(0, 3, 1, 2).float().to(self.device) / 255.0).reshape(-1, 512),
|
||||||
|
torch.tensor(intermediate).float().to(self.device)
|
||||||
|
), axis=-1)
|
||||||
|
u = u.float().to(self.device)
|
||||||
|
u = self.dialogue_listener_pre_ln(u)
|
||||||
|
y = self.dialogue_listener(u)
|
||||||
|
y = y.reshape(-1, y.shape[-1])
|
||||||
|
if return_feats:
|
||||||
|
_y = y.clone().detach().cpu().numpy()
|
||||||
|
y = self.proj_y(y)
|
||||||
|
|
||||||
|
if experiment == 2:
|
||||||
|
pred, label = self.decode_own_missing_knowledge(z, y, game, sel1, sel2, incremental)
|
||||||
|
elif experiment == 3:
|
||||||
|
pred, label = self.decode_partner_missing_knowledge(z, y, game, sel1, sel2, incremental)
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong experiment id! Valid values are 2 and 3.')
|
||||||
|
|
||||||
|
if return_feats:
|
||||||
|
return pred, label, [sel1, sel2], _y
|
||||||
|
|
||||||
|
return pred, label, [sel1, sel2]
|
||||||
|
|
||||||
|
def decode_own_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
|
||||||
|
if incremental:
|
||||||
|
if sel1:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder1.decode(z, f.repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder2.decode(z, f.repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
pred = torch.stack(
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
else:
|
||||||
|
if sel1:
|
||||||
|
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)
|
||||||
|
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)
|
||||||
|
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
|
||||||
|
return (pred, label)
|
||||||
|
|
||||||
|
def decode_partner_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
|
||||||
|
if incremental:
|
||||||
|
if sel1:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder1.decode(z, f.repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder2.decode(z, f.repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
pred = torch.stack(
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
else:
|
||||||
|
if sel1:
|
||||||
|
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)
|
||||||
|
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)
|
||||||
|
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
|
||||||
|
return (pred, label)
|
||||||
|
|
291
src/models/plan_model_graphs_oracle.py
Normal file
291
src/models/plan_model_graphs_oracle.py
Normal file
|
@ -0,0 +1,291 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn, numpy as np
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from torch_geometric.nn import MeanAggregation, GATv2Conv
|
||||||
|
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class PlanGraphEmbedder(nn.Module):
|
||||||
|
def __init__(self, device, h_dim, dropout=0.0, heads=4):
|
||||||
|
super().__init__()
|
||||||
|
self.device = device
|
||||||
|
self.proj_x = nn.Sequential(
|
||||||
|
nn.Linear(27, h_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
self.proj_edge_attr = nn.Sequential(
|
||||||
|
nn.Linear(12, h_dim),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(dropout)
|
||||||
|
)
|
||||||
|
self.conv1 = GATv2Conv(h_dim, h_dim, heads=heads, edge_dim=h_dim)
|
||||||
|
self.conv2 = GATv2Conv(h_dim*heads, h_dim, heads=1, edge_dim=h_dim)
|
||||||
|
self.act = nn.GELU()
|
||||||
|
self.dropout = nn.Dropout(dropout)
|
||||||
|
self.dec = nn.Linear(h_dim*3, 1)
|
||||||
|
|
||||||
|
def encode(self, data):
|
||||||
|
x, edge_index, edge_attr = data.features.to(self.device), data.edge_index.to(self.device), data.tool.to(self.device)
|
||||||
|
x = self.proj_x(x)
|
||||||
|
edge_attr = self.proj_edge_attr(edge_attr)
|
||||||
|
x = self.conv1(x, edge_index, edge_attr)
|
||||||
|
x = self.act(x)
|
||||||
|
x = self.dropout(x)
|
||||||
|
x = self.conv2(x, edge_index, edge_attr)
|
||||||
|
return x, edge_attr
|
||||||
|
|
||||||
|
def decode(self, z, context, edge_label_index):
|
||||||
|
u = z[edge_label_index[0]]
|
||||||
|
v = z[edge_label_index[1]]
|
||||||
|
return self.dec(torch.cat((u, v, context), -1))
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, seq_model_type, device):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
plan_emb_out = 128
|
||||||
|
self.plan_embedder0 = PlanGraphEmbedder(device, plan_emb_out)
|
||||||
|
self.plan_embedder1 = PlanGraphEmbedder(device, plan_emb_out)
|
||||||
|
self.plan_embedder2 = PlanGraphEmbedder(device, plan_emb_out)
|
||||||
|
self.plan_pool = MeanAggregation()
|
||||||
|
dlist_hidden = 1024
|
||||||
|
frame_emb = 512
|
||||||
|
drnn_in = 5*1024 + 2 + frame_emb + 1024
|
||||||
|
self.dialogue_listener_pre_ln = nn.LayerNorm(drnn_in)
|
||||||
|
if seq_model_type==0:
|
||||||
|
self.dialogue_listener_rnn = nn.GRU(drnn_in, dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
|
||||||
|
elif seq_model_type==1:
|
||||||
|
self.dialogue_listener_rnn = nn.LSTM(drnn_in, dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1, 1, drnn_in))[0]
|
||||||
|
elif seq_model_type==2:
|
||||||
|
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0], x.shape[0]), diagonal=1).bool().to(device)
|
||||||
|
sincos_fun = lambda x:torch.transpose(torch.stack([
|
||||||
|
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
|
||||||
|
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
|
||||||
|
]),0,1).reshape(-1,1,2)
|
||||||
|
self.dialogue_listener_lin1 = nn.Linear(drnn_in, dlist_hidden-2)
|
||||||
|
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
|
||||||
|
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x, x, x, attn_mask=mask_fun(x))
|
||||||
|
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
|
||||||
|
sincos_fun(x.shape[0]).float().to(self.device),
|
||||||
|
self.dialogue_listener_lin1(x).reshape(-1, 1, dlist_hidden-2)
|
||||||
|
], axis=-1))[0]
|
||||||
|
else:
|
||||||
|
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
|
||||||
|
exit()
|
||||||
|
conv_block = lambda i, o, k, p, s: nn.Sequential(
|
||||||
|
nn.Conv2d(i, o, k, padding=p, stride=s),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
conv_block(3, 8, 3, 1, 1),
|
||||||
|
conv_block(8, 32, 5, 2, 2),
|
||||||
|
conv_block(32, frame_emb//4, 5, 2, 2),
|
||||||
|
nn.Conv2d(frame_emb//4, frame_emb, 3),
|
||||||
|
nn.ReLU(),
|
||||||
|
)
|
||||||
|
self.proj_y = nn.Sequential(
|
||||||
|
nn.Linear(1024, 512),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear(512, plan_emb_out)
|
||||||
|
)
|
||||||
|
self.proj_tom = nn.Linear(154, 5*1024)
|
||||||
|
|
||||||
|
def parse_ql(self, q, game, intermediate):
|
||||||
|
tom12_answ = ['YES', 'NO', 'MAYBE']
|
||||||
|
materials_dict = game.materials_dict.copy()
|
||||||
|
materials_dict['NOT_SURE'] = 0
|
||||||
|
if not q is None:
|
||||||
|
q, l = q
|
||||||
|
tom_gt = np.concatenate([onehot(q[2],2), onehot(q[3],2)])
|
||||||
|
#### q1
|
||||||
|
q1_1 = np.concatenate([onehot(q[4][0][0]+1,2), onehot(materials_dict[q[4][0][1]], len(game.materials_dict))])
|
||||||
|
## r1
|
||||||
|
r1_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][0])]
|
||||||
|
#### q2
|
||||||
|
q2_1 = np.concatenate([onehot(q[4][1][0]+1,2), onehot(materials_dict[q[4][1][1]], len(game.materials_dict))])
|
||||||
|
## r2
|
||||||
|
r2_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][1])]
|
||||||
|
#### q3
|
||||||
|
q3_1 = onehot(q[4][2]+1, 2)
|
||||||
|
## r3
|
||||||
|
r3_1 = onehot(materials_dict[l[0][2]], len(game.materials_dict))
|
||||||
|
#### q1
|
||||||
|
q1_2 = np.concatenate([onehot(q[5][0][0]+1,2), onehot(materials_dict[q[5][0][1]], len(game.materials_dict))])
|
||||||
|
## r1
|
||||||
|
r1_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][0])]
|
||||||
|
#### q2
|
||||||
|
q2_2 = np.concatenate([onehot(q[5][1][0]+1,2), onehot(materials_dict[q[5][1][1]], len(game.materials_dict))])
|
||||||
|
## r2
|
||||||
|
r2_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][1])]
|
||||||
|
#### q3
|
||||||
|
q3_2 = onehot(q[5][2]+1,2)
|
||||||
|
## r3
|
||||||
|
r3_2 = onehot(materials_dict[l[1][2]], len(game.materials_dict))
|
||||||
|
if intermediate == 0:
|
||||||
|
tom_gt = np.zeros(154)
|
||||||
|
elif intermediate == 1:
|
||||||
|
# tom6
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0] + q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0] + q3_2.shape[0] + r3_2.shape[0])])
|
||||||
|
elif intermediate == 2:
|
||||||
|
# tom7
|
||||||
|
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0] + q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
|
||||||
|
elif intermediate == 3:
|
||||||
|
# tom6 + tom7
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
|
||||||
|
elif intermediate == 4:
|
||||||
|
# tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0] + q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0] + q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
|
||||||
|
elif intermediate == 5:
|
||||||
|
# tom6 + tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
|
||||||
|
elif intermediate == 6:
|
||||||
|
# tom7 + tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, q3_2, r3_2])
|
||||||
|
elif intermediate == 7:
|
||||||
|
# tom6 + tom7 + tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, q3_1, r3_1, q1_2, r1_2, q2_2, r2_2, q3_2, r3_2])
|
||||||
|
else:
|
||||||
|
tom_gt = np.zeros(154)
|
||||||
|
if tom_gt.shape[0] != 154: breakpoint()
|
||||||
|
return tom_gt
|
||||||
|
|
||||||
|
def forward(self, game, experiment, global_plan=False, player_plan=False, incremental=False, intermediate=0):
|
||||||
|
|
||||||
|
l = list(game)
|
||||||
|
_, d, l, q, f, _, _, _ = zip(*list(game))
|
||||||
|
|
||||||
|
tom_gt = [self.parse_ql(x, game, intermediate) for x in q]
|
||||||
|
tom_gt = torch.tensor(np.stack(tom_gt), device=self.device, dtype=torch.float32)
|
||||||
|
tom_gt = self.proj_tom(tom_gt)
|
||||||
|
|
||||||
|
f = np.array(f, dtype=np.uint8)
|
||||||
|
d = np.stack([np.concatenate(([int(x[0][1]==2), int(x[0][1]==1)], x[0][-1])) if not x is None else np.zeros(1026) for x in d])
|
||||||
|
try:
|
||||||
|
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
|
||||||
|
sel2 = 1 - sel1
|
||||||
|
except Exception as e:
|
||||||
|
sel1 = 0
|
||||||
|
sel2 = 0
|
||||||
|
|
||||||
|
if player_plan:
|
||||||
|
if sel1:
|
||||||
|
z, _ = self.plan_embedder1.encode(game.player1_plan)
|
||||||
|
elif sel2:
|
||||||
|
z, _ = self.plan_embedder2.encode(game.player2_plan)
|
||||||
|
else:
|
||||||
|
z, _ = self.plan_embedder0.encode(game.global_plan)
|
||||||
|
else:
|
||||||
|
raise ValueError('There should never be a global plan!')
|
||||||
|
|
||||||
|
u = torch.cat((
|
||||||
|
torch.tensor(d).float().to(self.device),
|
||||||
|
self.conv(torch.tensor(f).permute(0, 3, 1, 2).float().to(self.device) / 255.0).reshape(-1, 512),
|
||||||
|
tom_gt
|
||||||
|
), axis=-1)
|
||||||
|
u = u.float().to(self.device)
|
||||||
|
u = self.dialogue_listener_pre_ln(u)
|
||||||
|
y = self.dialogue_listener(u)
|
||||||
|
y = y.reshape(-1, y.shape[-1])
|
||||||
|
y = self.proj_y(y)
|
||||||
|
|
||||||
|
if experiment == 2:
|
||||||
|
pred, label = self.decode_own_missing_knowledge(z, y, game, sel1, sel2, incremental)
|
||||||
|
elif experiment == 3:
|
||||||
|
pred, label = self.decode_partner_missing_knowledge(z, y, game, sel1, sel2, incremental)
|
||||||
|
else:
|
||||||
|
raise ValueError('Wrong experiment id! Valid values are 2 and 3.')
|
||||||
|
|
||||||
|
return pred, label, [sel1, sel2]
|
||||||
|
|
||||||
|
def decode_own_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
|
||||||
|
if incremental:
|
||||||
|
if sel1:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder1.decode(z, f.repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder2.decode(z, f.repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
pred = torch.stack(
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
else:
|
||||||
|
if sel1:
|
||||||
|
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_edge_label_index.shape[1], 1), game.player1_edge_label_index).view(-1)
|
||||||
|
label = game.player1_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_edge_label_index.shape[1], 1), game.player2_edge_label_index).view(-1)
|
||||||
|
label = game.player2_edge_label_own_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
|
||||||
|
return (pred, label)
|
||||||
|
|
||||||
|
def decode_partner_missing_knowledge(self, z, y, game, sel1, sel2, incremental):
|
||||||
|
if incremental:
|
||||||
|
if sel1:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder1.decode(z, y[0].repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder1.decode(z, f.repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = torch.stack(
|
||||||
|
[self.plan_embedder2.decode(z, y[0].repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder2.decode(z, f.repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
pred = torch.stack(
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
[self.plan_embedder0.decode(z, y[0].repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)]
|
||||||
|
+
|
||||||
|
[self.plan_embedder0.decode(z, f.repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1) for f in y[len(y)%10-1::10]]
|
||||||
|
)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
else:
|
||||||
|
if sel1:
|
||||||
|
pred = self.plan_embedder1.decode(z, y.mean(0, keepdim=True).repeat(game.player1_plan.edge_index.shape[1], 1), game.player1_plan.edge_index).view(-1)
|
||||||
|
label = game.player1_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
elif sel2:
|
||||||
|
pred = self.plan_embedder2.decode(z, y.mean(0, keepdim=True).repeat(game.player2_plan.edge_index.shape[1], 1), game.player2_plan.edge_index).view(-1)
|
||||||
|
label = game.player2_edge_label_other_missing_knowledge.to(self.device)
|
||||||
|
else:
|
||||||
|
# NOTE: here I use game.global_plan.edge_index instead of edge_label_index => no negative sampling
|
||||||
|
pred = self.plan_embedder0.decode(z, y.mean(0, keepdim=True).repeat(game.global_plan.edge_index.shape[1], 1), game.global_plan.edge_index).view(-1)
|
||||||
|
label = torch.zeros(game.global_plan.edge_index.shape[1])
|
||||||
|
|
||||||
|
return (pred, label)
|
214
src/models/plan_model_oracle.py
Normal file
214
src/models/plan_model_oracle.py
Normal file
|
@ -0,0 +1,214 @@
|
||||||
|
import sys, torch, random
|
||||||
|
from numpy.core.fromnumeric import reshape
|
||||||
|
import torch.nn as nn, numpy as np
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from src.data.game_parser import DEVICE
|
||||||
|
|
||||||
|
def onehot(x,n):
|
||||||
|
retval = np.zeros(n)
|
||||||
|
if x > 0:
|
||||||
|
retval[x-1] = 1
|
||||||
|
return retval
|
||||||
|
|
||||||
|
class Model(nn.Module):
|
||||||
|
def __init__(self, seq_model_type=0,device=DEVICE):
|
||||||
|
super(Model, self).__init__()
|
||||||
|
|
||||||
|
self.device = device
|
||||||
|
|
||||||
|
my_rnn = lambda i,o: nn.GRU(i,o)
|
||||||
|
|
||||||
|
plan_emb_in = 81
|
||||||
|
plan_emb_out = 32
|
||||||
|
q_emb = 100
|
||||||
|
|
||||||
|
self.plan_embedder0 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
self.plan_embedder1 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
self.plan_embedder2 = my_rnn(plan_emb_in,plan_emb_out)
|
||||||
|
|
||||||
|
dlist_hidden = 1024
|
||||||
|
frame_emb = 512
|
||||||
|
drnn_in = 5*1024 + 2 + frame_emb + 1024
|
||||||
|
|
||||||
|
my_rnn = lambda i,o: nn.LSTM(i,o)
|
||||||
|
|
||||||
|
if seq_model_type==0:
|
||||||
|
self.dialogue_listener_rnn = nn.GRU(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==1:
|
||||||
|
self.dialogue_listener_rnn = nn.LSTM(drnn_in,dlist_hidden)
|
||||||
|
self.dialogue_listener = lambda x: \
|
||||||
|
self.dialogue_listener_rnn(x.reshape(-1,1,drnn_in))[0]
|
||||||
|
elif seq_model_type==2:
|
||||||
|
mask_fun = lambda x: torch.triu(torch.ones(x.shape[0],x.shape[0]),diagonal=1).bool().to(device)
|
||||||
|
sincos_fun = lambda x:torch.transpose(torch.stack([
|
||||||
|
torch.sin(2*np.pi*torch.tensor(list(range(x)))/x),
|
||||||
|
torch.cos(2*np.pi*torch.tensor(list(range(x)))/x)
|
||||||
|
]),0,1).reshape(-1,1,2)
|
||||||
|
self.dialogue_listener_lin1 = nn.Linear(drnn_in,dlist_hidden-2)
|
||||||
|
self.dialogue_listener_attn = nn.MultiheadAttention(dlist_hidden, 8)
|
||||||
|
self.dialogue_listener_wrap = lambda x: self.dialogue_listener_attn(x,x,x,attn_mask=mask_fun(x))
|
||||||
|
self.dialogue_listener = lambda x: self.dialogue_listener_wrap(torch.cat([
|
||||||
|
sincos_fun(x.shape[0]).float().to(self.device),
|
||||||
|
self.dialogue_listener_lin1(x).reshape(-1,1,dlist_hidden-2)
|
||||||
|
], axis=-1))[0]
|
||||||
|
else:
|
||||||
|
print('Sequence model type must be in (0: GRU, 1: LSTM, 2: Transformer), but got ', seq_model_type)
|
||||||
|
exit()
|
||||||
|
|
||||||
|
conv_block = lambda i,o,k,p,s: nn.Sequential(
|
||||||
|
nn.Conv2d( i, o, k, padding=p, stride=s),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.MaxPool2d(2),
|
||||||
|
nn.BatchNorm2d(o),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.ReLU()
|
||||||
|
)
|
||||||
|
|
||||||
|
self.conv = nn.Sequential(
|
||||||
|
conv_block( 3, 8, 3, 1, 1),
|
||||||
|
conv_block( 8, 32, 5, 2, 2),
|
||||||
|
conv_block( 32, frame_emb//4, 5, 2, 2),
|
||||||
|
nn.Conv2d( frame_emb//4, frame_emb, 3),nn.ReLU(),
|
||||||
|
)
|
||||||
|
|
||||||
|
plan_layer = lambda i,o : nn.Sequential(
|
||||||
|
nn.Linear(i,(i+2*o)//3),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
nn.Linear((i+2*o)//3,o),
|
||||||
|
nn.GELU(),
|
||||||
|
nn.Dropout(0.5),
|
||||||
|
)
|
||||||
|
|
||||||
|
plan_mat_size = 21*21
|
||||||
|
q_in_size = 3*plan_emb_out+dlist_hidden
|
||||||
|
q_in_size = 3*plan_emb_out+dlist_hidden+plan_mat_size
|
||||||
|
q_in_size = dlist_hidden+plan_mat_size
|
||||||
|
|
||||||
|
self.plan_out = plan_layer(q_in_size,plan_mat_size)
|
||||||
|
|
||||||
|
self.proj_tom = nn.Linear(154, 5*1024)
|
||||||
|
|
||||||
|
def forward(self,game,global_plan=False, player_plan=False,evaluation=False, incremental=False, intermediate=0):
|
||||||
|
|
||||||
|
_,d,l,q,f,_,_,_ = zip(*list(game))
|
||||||
|
|
||||||
|
tom_gt = [self.parse_ql(x, game, intermediate) for x in q]
|
||||||
|
tom_gt = torch.tensor(np.stack(tom_gt), device=self.device, dtype=torch.float32)
|
||||||
|
tom_gt = self.proj_tom(tom_gt)
|
||||||
|
|
||||||
|
f = np.array(f, dtype=np.uint8)
|
||||||
|
|
||||||
|
d = np.stack([np.concatenate(([int(x[0][1]==2),int(x[0][1]==1)],x[0][-1])) if not x is None else np.zeros(1026) for x in d])
|
||||||
|
|
||||||
|
try:
|
||||||
|
sel1 = int([x[0][2] for x in q if not x is None][0] == 1)
|
||||||
|
sel2 = 1 - sel1
|
||||||
|
except Exception as e:
|
||||||
|
sel1 = 0
|
||||||
|
sel2 = 0
|
||||||
|
|
||||||
|
if not global_plan and not player_plan:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
plan_emb = 0*plan_emb
|
||||||
|
elif global_plan:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
plan_emb = torch.cat([
|
||||||
|
0*self.plan_embedder0(torch.stack(list(map(torch.tensor,game.global_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
sel1*self.plan_embedder1(torch.stack(list(map(torch.tensor,game.player1_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0],
|
||||||
|
sel2*self.plan_embedder2(torch.stack(list(map(torch.tensor,game.player2_plan))).reshape(-1,1,81).float().to(self.device))[0][-1][0]
|
||||||
|
])
|
||||||
|
|
||||||
|
u = torch.cat((
|
||||||
|
torch.tensor(d).float().to(self.device),
|
||||||
|
self.conv(torch.tensor(f).permute(0,3,1,2).float().to(self.device)).reshape(-1,512),
|
||||||
|
tom_gt
|
||||||
|
),axis=-1)
|
||||||
|
u = u.float().to(self.device)
|
||||||
|
|
||||||
|
y = self.dialogue_listener(u)
|
||||||
|
y = y.reshape(-1,y.shape[-1])
|
||||||
|
|
||||||
|
if incremental:
|
||||||
|
prediction = torch.stack([
|
||||||
|
self.plan_out(torch.cat((y[0],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))] + [
|
||||||
|
self.plan_out(torch.cat((f,torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device)))) for f in y[len(y)%10-1::10]
|
||||||
|
])
|
||||||
|
prediction = F.softmax(prediction.reshape(-1,21,21),-1).reshape(-1,21*21)
|
||||||
|
else:
|
||||||
|
prediction = self.plan_out(torch.cat((y[-1],torch.tensor(game.plan_repr.reshape(-1)).float().to(self.device))))
|
||||||
|
prediction = F.softmax(prediction.reshape(21,21),-1).reshape(21*21)
|
||||||
|
|
||||||
|
return prediction, y
|
||||||
|
|
||||||
|
def parse_ql(self, q, game, intermediate):
|
||||||
|
tom12_answ = ['YES', 'NO', 'MAYBE']
|
||||||
|
materials_dict = game.materials_dict.copy()
|
||||||
|
materials_dict['NOT_SURE'] = 0
|
||||||
|
if not q is None:
|
||||||
|
q, l = q
|
||||||
|
tom_gt = np.concatenate([onehot(q[2],2), onehot(q[3],2)])
|
||||||
|
#### q1
|
||||||
|
q1_1 = np.concatenate([onehot(q[4][0][0]+1,2), onehot(materials_dict[q[4][0][1]], len(game.materials_dict))])
|
||||||
|
## r1
|
||||||
|
r1_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][0])]
|
||||||
|
#### q2
|
||||||
|
q2_1 = np.concatenate([onehot(q[4][1][0]+1,2), onehot(materials_dict[q[4][1][1]], len(game.materials_dict))])
|
||||||
|
## r2
|
||||||
|
r2_1 = np.eye(len(tom12_answ))[tom12_answ.index(l[0][1])]
|
||||||
|
#### q3
|
||||||
|
q3_1 = onehot(q[4][2]+1, 2)
|
||||||
|
## r3
|
||||||
|
r3_1 = onehot(materials_dict[l[0][2]], len(game.materials_dict))
|
||||||
|
#### q1
|
||||||
|
q1_2 = np.concatenate([onehot(q[5][0][0]+1,2), onehot(materials_dict[q[5][0][1]], len(game.materials_dict))])
|
||||||
|
## r1
|
||||||
|
r1_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][0])]
|
||||||
|
#### q2
|
||||||
|
q2_2 = np.concatenate([onehot(q[5][1][0]+1,2), onehot(materials_dict[q[5][1][1]], len(game.materials_dict))])
|
||||||
|
## r2
|
||||||
|
r2_2 = np.eye(len(tom12_answ))[tom12_answ.index(l[1][1])]
|
||||||
|
#### q3
|
||||||
|
q3_2 = onehot(q[5][2]+1,2)
|
||||||
|
## r3
|
||||||
|
r3_2 = onehot(materials_dict[l[1][2]], len(game.materials_dict))
|
||||||
|
if intermediate == 0:
|
||||||
|
tom_gt = np.zeros(154)
|
||||||
|
elif intermediate == 1:
|
||||||
|
# tom6
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0] + q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0] + q3_2.shape[0] + r3_2.shape[0])])
|
||||||
|
elif intermediate == 2:
|
||||||
|
# tom7
|
||||||
|
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0] + q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
|
||||||
|
elif intermediate == 3:
|
||||||
|
# tom6 + tom7
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, np.zeros(q3_1.shape[0] + r3_1.shape[0]), q1_2, r1_2, q2_2, r2_2, np.zeros(q3_2.shape[0] + r3_2.shape[0])])
|
||||||
|
elif intermediate == 4:
|
||||||
|
# tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0] + q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0] + q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
|
||||||
|
elif intermediate == 5:
|
||||||
|
# tom6 + tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, np.zeros(q2_1.shape[0] + r2_1.shape[0]), q3_1, r3_1, q1_2, r1_2, np.zeros(q2_2.shape[0] + r2_2.shape[0]), q3_2, r3_2])
|
||||||
|
elif intermediate == 6:
|
||||||
|
# tom7 + tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, np.zeros(q1_1.shape[0] + r1_1.shape[0]), q2_1, r2_1, q3_1, r3_1, np.zeros(q1_2.shape[0] + r1_2.shape[0]), q2_2, r2_2, q3_2, r3_2])
|
||||||
|
elif intermediate == 7:
|
||||||
|
# tom6 + tom7 + tom8
|
||||||
|
tom_gt = np.concatenate([tom_gt, q1_1, r1_1, q2_1, r2_1, q3_1, r3_1, q1_2, r1_2, q2_2, r2_2, q3_2, r3_2])
|
||||||
|
else:
|
||||||
|
tom_gt = np.zeros(154)
|
||||||
|
if tom_gt.shape[0] != 154: breakpoint()
|
||||||
|
return tom_gt
|
Loading…
Reference in a new issue