from skorch import NeuralNetClassifier
model = NeuralNetClassifier(
module=MyClassifier, # Класс модели на PyTorch
lr=0.001, # Скорость обучения
batch_size=64, # Размер батча
criterion=nn.CrossEntropyLoss, # Функция потерь
optimizer=optim.Adam # Оптимизатор
)
Здесь создаётся обёртка NeuralNetClassifier, которая делает модель PyTorch совместимой с .fit(), .predict() и другими методами Sklearn.
📌Обучение:
model.fit(X_train, y_train)
Ты обучаешь модель так же, как и в Sklearn. Это удобно и не требует написания собственного цикла обучения.
С помощью Skorch ты получаешь:
- удобный Sklearn-подобный API для PyTorch-моделей;
- автоматический вывод метрик обучения;
- лёгкую интеграцию с GridSearchCV, Pipeline и другими инструментами Scikit-learn.
https://github.com/skorch-dev/skorch
@machinelearning_interview