from dataclasses import dataclass
from typing import Optional, Tuple

import torch

from transformers.modeling_outputs import ModelOutput


@dataclass
class Seq2SeqV2DialOutput(ModelOutput):
    """Output container for a seq2seq dialogue model.

    Mirrors the fields of the standard Hugging Face seq2seq LM output
    (loss, logits, cached key/values, decoder/encoder hidden states and
    attentions) and additionally carries the raw `encoder_outputs`.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
    encoder_outputs: Optional[Tuple[torch.FloatTensor]] = None
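

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows how
# a seq2seq dialogue model's forward pass might populate this output class.
# The tensor shapes and values below are made-up assumptions for the demo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, tgt_len, vocab_size = 2, 5, 32000  # assumed toy dimensions
    dummy_logits = torch.randn(batch_size, tgt_len, vocab_size)
    dummy_loss = torch.tensor(1.23)

    output = Seq2SeqV2DialOutput(loss=dummy_loss, logits=dummy_logits)

    # ModelOutput supports attribute access and keeps only the non-None
    # fields in its dict view.
    print(output.logits.shape)   # torch.Size([2, 5, 32000])
    print(list(output.keys()))   # ['loss', 'logits']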