Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion atomgen/models/configuration_atomformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers.configuration_utils import PretrainedConfig


class AtomformerConfig(PretrainedConfig):
class AtomformerConfig(PretrainedConfig): # type: ignore[no-untyped-call]
r"""
Configuration of a :class:`~transform:class:`~transformers.AtomformerModel`.

Expand Down
10 changes: 5 additions & 5 deletions atomgen/models/modeling_atomformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2550,7 +2550,7 @@ def forward(
class AtomformerPreTrainedModel(PreTrainedModel): # type: ignore[no-untyped-call]
"""Base class for all transformer models."""

config_class = AtomformerConfig # type: ignore[assignment]
config_class = AtomformerConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ParallelBlock"]
Expand Down Expand Up @@ -2968,9 +2968,9 @@ def __init__(self, config: AtomformerConfig):

if self.problem_type == "regression":
self.loss_fct = nn.L1Loss()
elif self.problem_type == "classification":
elif self.problem_type == "classification": # type: ignore[comparison-overlap]
self.loss_fct = nn.BCEWithLogitsLoss()
elif self.problem_type == "multiclass_classification":
elif self.problem_type == "multiclass_classification": # type: ignore[comparison-overlap]
self.loss_fct = nn.CrossEntropyLoss()

def forward(
Expand All @@ -2989,9 +2989,9 @@ def forward(

loss = None
if labels is not None:
if self.problem_type == "multiclass_classification":
if self.problem_type == "multiclass_classification": # type: ignore[comparison-overlap]
labels = labels.long()
elif self.problem_type == "classification":
elif self.problem_type == "classification": # type: ignore[comparison-overlap]
labels = labels.float()

loss = self.loss_fct(pred.squeeze(), labels.squeeze())
Expand Down
4 changes: 2 additions & 2 deletions atomgen/models/schnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from transformers.modeling_utils import PreTrainedModel


class SchNetConfig(PretrainedConfig):
class SchNetConfig(PretrainedConfig): # type: ignore[no-untyped-call]
r"""
Stores the configuration of a :class:`~transformers.SchNetModel`.

Expand Down Expand Up @@ -134,7 +134,7 @@ class SchNetPreTrainedModel(PreTrainedModel): # type: ignore[no-untyped-call]
simple interface for loading and exporting models.
"""

config_class = SchNetConfig # type: ignore[assignment]
config_class = SchNetConfig
base_model_prefix = "model"
supports_gradient_checkpointing = False

Expand Down
4 changes: 2 additions & 2 deletions atomgen/models/tokengt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,7 @@ def forward(
return out


class TransformerConfig(PretrainedConfig):
class TransformerConfig(PretrainedConfig): # type: ignore[no-untyped-call]
"""Configuration class to store the configuration of a TokenGT model."""

def __init__(
Expand Down Expand Up @@ -2510,7 +2510,7 @@ def custom_forward(*inputs: Any) -> Any:
class TransformerPreTrainedModel(PreTrainedModel): # type: ignore[no-untyped-call]
"""Base class for all transformer models."""

config_class = TransformerConfig # type: ignore[assignment]
config_class = TransformerConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["ParallelBlock"]
Expand Down
2 changes: 1 addition & 1 deletion scripts/training/pretrain_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def train(args: argparse.Namespace) -> None:
wandb.login(key=os.environ["WANDB_API_KEY"])
wandb.init(project=args.project, config=vars(args), name=args.name)

training_args = TrainingArguments(
training_args = TrainingArguments( # type: ignore[call-arg]
output_dir=args.output_dir,
learning_rate=args.learning_rate,
lr_scheduler_type=args.lr_scheduler_type,
Expand Down
Loading