Active learner save method does not work for SetFit #71

@jakobstgl

Bug description

I set up a classification factory for a SetFit model, initialized the PoolBasedActiveLearner, and started the query process successfully. However, when I tried to save the active learner to disk for later use, the call failed with: TypeError: cannot pickle 'ConfigModuleInstance' object

Steps to reproduce

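# Imports assumed from small-text's documented package layout.
# `dataset` is an already prepared small-text dataset (not shown here).
import numpy as np
from small_text import BreakingTies, PoolBasedActiveLearner, SubsamplingQueryStrategy
from small_text.integrations.transformers.classifiers.factories import SetFitClassificationFactory
from small_text.integrations.transformers.classifiers.setfit import SetFitModelArguments


def initialize_active_learner(active_learner, y_train):
    # Use the first 50 samples as the initial labeled set (y_train is unused).
    indices_initial = range(0, 50)
    active_learner.initialize(np.array(indices_initial))

    return indices_initial


num_classes = 2
sentence_transformer_model_name = 'sentence-transformers/paraphrase-mpnet-base-v2'
setfit_model_args = SetFitModelArguments(sentence_transformer_model_name)
clf_factory = SetFitClassificationFactory(
    setfit_model_args,
    num_classes,
    classification_kwargs={
        'device': 'cuda',
        'max_seq_len': 64,
        'mini_batch_size': 8
    }
)

# The query strategy must be defined before the active learner is constructed.
query_strategy = SubsamplingQueryStrategy(BreakingTies(), subsample_size=10000)
active_learner = PoolBasedActiveLearner(clf_factory, query_strategy, dataset=dataset)

indices_initial = initialize_active_learner(active_learner, dataset.y)
indices_labeled = indices_initial

# Raises TypeError: cannot pickle 'ConfigModuleInstance' object
active_learner.save("test.pkl")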
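
To narrow down which member of the active learner triggers the TypeError, a small diagnostic like the one below might help. This is only a sketch using the standard library: `find_unpicklable`, `FakeLearner`, and `FakeClassifier` are hypothetical names introduced for illustration, and a `threading.Lock` stands in for the unpicklable object, since the real attribute layout of the small-text learner is not shown in this report.

```python
import pickle
import threading


def find_unpicklable(obj, path="obj", max_depth=5, _seen=None):
    """Return (attribute path, error) pairs for the innermost members that fail to pickle."""
    _seen = set() if _seen is None else _seen
    if id(obj) in _seen:
        return []  # already visited (guards against reference cycles)
    _seen.add(id(obj))

    try:
        pickle.dumps(obj)
        return []  # this object pickles fine, nothing to report
    except Exception as exc:
        error = f"{type(exc).__name__}: {exc}"

    if max_depth == 0:
        return [(path, error)]

    # Collect the children that could be responsible for the failure.
    if isinstance(obj, dict):
        children = {f"{path}[{key!r}]": value for key, value in obj.items()}
    elif isinstance(obj, (list, tuple, set)):
        children = {f"{path}[{i}]": value for i, value in enumerate(obj)}
    elif hasattr(obj, "__dict__"):
        children = {f"{path}.{name}": value for name, value in vars(obj).items()}
    else:
        children = {}

    failures = []
    for child_path, child in children.items():
        failures.extend(find_unpicklable(child, child_path, max_depth - 1, _seen))

    # If no child explains the failure, the object itself is the culprit.
    return failures or [(path, error)]


class FakeClassifier:
    """Hypothetical stand-in for a classifier holding an unpicklable member."""

    def __init__(self):
        self.config = threading.Lock()  # locks cannot be pickled
        self.num_classes = 2


class FakeLearner:
    """Hypothetical stand-in for the active learner."""

    def __init__(self):
        self.classifier = FakeClassifier()
        self.indices_labeled = [0, 1, 2]


failures = find_unpicklable(FakeLearner(), "active_learner")
for attribute_path, error in failures:
    print(attribute_path, "->", error)
```

Calling `find_unpicklable(active_learner, "active_learner")` on the real object right before `save` should, under the same assumptions, point to the attribute that holds the 'ConfigModuleInstance' object.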

Expected behavior

The active learner should be serialized and saved to disk, just as it is with a regular transformer-based active learner.

Environment:

Python version: 3.9.6
small-text version: 2.0.0.dev1
small-text integrations (e.g., transformers): setfit 1.1.0, transformers 4.45.2
PyTorch version (if applicable): 2.5.0+cu124

Installation (pip, conda, or from source): pip
CUDA version (if applicable): 12.6

Additional information
