V2Dial/models/backbones/encoder_decoder/outputs.py
2025-06-24 08:38:09 +02:00

19 lines
805 B
Python

from typing import Optional, Tuple
import torch
from transformers.modeling_outputs import ModelOutput
from dataclasses import dataclass
@dataclass
class Seq2SeqV2DialOutput(ModelOutput):
    """Output container returned by the V2Dial encoder-decoder backbone.

    The field names follow Hugging Face ``transformers``' ``Seq2SeqLMOutput``.
    There is one extra trailing field, ``encoder_outputs``.

    Every field defaults to ``None``, so a producer fills in only what it
    actually computed. As a ``ModelOutput`` subclass, an instance can be read:

    - by attribute (``out.logits``);
    - by key (``out["logits"]``);
    - by integer index, where ``None`` fields are skipped.

    Tensor shapes and dtypes are not fixed here. They depend on the model code
    that populates the instance (not visible in this file).
    """

    # Loss value, if the producing forward pass computed one.
    loss: Optional[torch.FloatTensor] = None
    # Decoder output scores. Annotated Optional because the default is None;
    # the previous non-Optional annotation contradicted that default.
    logits: Optional[torch.FloatTensor] = None
    # Nested tuples of cached tensors. Presumably per-layer key/value caches for
    # incremental decoding, as in Seq2SeqLMOutput — confirm against the decoder.
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor, ...], ...]] = None
    # Optional per-layer decoder diagnostics.
    decoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    decoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    cross_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Encoder-side results.
    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    encoder_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Extra field, not part of Seq2SeqLMOutput. Presumably the raw encoder
    # output, carried along so callers can reuse it; verify its actual type
    # against the producer.
    encoder_outputs: Optional[Tuple[torch.FloatTensor, ...]] = None