Refactor ContrastiveDataset and ContrastiveDistillationDataset #579

Open · wants to merge 10 commits into main
refactor attempt for ContrastiveDataset so that it does not blow RAM with bigger dataset
DemirTonchev committed Dec 23, 2024
commit 6a6930383ebb332ecd944de95c4527af0ab61a1a
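For context, the core of the change is to stop materializing every positive and negative pair in Python lists and instead yield pairs lazily from generators, so memory stays roughly constant as the training set grows. A minimal sketch of the idea (illustrative only, not the PR code; assumes plain `sentences`/`labels` lists):

```python
from itertools import combinations
from typing import Dict, Iterator, List, Union


def iter_pairs(sentences: List[str], labels: List[Union[int, float]]) -> Iterator[Dict]:
    """Yield labelled sentence pairs one at a time instead of storing them all."""
    for (s1, l1), (s2, l2) in combinations(zip(sentences, labels), 2):
        # label 1.0 marks a positive (same-label) pair, 0.0 a negative pair
        yield {"sentence_1": s1, "sentence_2": s2, "label": float(l1 == l2)}
```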
src/setfit/sampler.py (108 changes: 107 additions & 1 deletion)
@@ -12,7 +12,7 @@
logger = logging.get_logger(__name__)


-def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
+def shuffle_combinations(iterable: Iterable, replacement: bool = False) -> Generator:
"""Generates shuffled pair combinations for any iterable data provided.

Args:
@@ -31,6 +31,112 @@ def shuffle_combinations(iterable: Iterable, replacement: bool = True) -> Generator:
yield iterable[_idx], iterable[idx]


class ContrastiveDatasetIt(IterableDataset):
def __init__(
self,
sentences: List[str],
labels: List[Union[int, float]],
multilabel: bool = False, # False for now
num_iterations: Optional[int] = None,
sampling_strategy: str = "oversampling",
max_pairs: int = -1,
) -> None:
"""Generates positive and negative text pairs for contrastive learning.

Args:
sentences (List[str]): text sentences to generate pairs from
labels (List[Union[int, float]]): labels for each sentence
multilabel: set to process "multilabel" labels array
sampling_strategy: "unique", "oversampling", or "undersampling"
num_iterations: if provided explicitly sets the number of pairs to be generated
where n_pairs = n_iterations * n_sentences * 2 (for pos & neg pairs)
max_pairs: If not -1, then we only sample pairs until we have certainly reached
max_pairs pairs.
"""
super().__init__()
self.pos_index = 0
self.neg_index = 0
self.pos_pairs = []
self.neg_pairs = []
self.sentences = sentences
self.labels = labels
self.sentence_labels = list(zip(self.sentences, self.labels))
self.max_pos_or_neg = np.inf if max_pairs == -1 else max_pairs // 2

from collections import Counter

label_counts = Counter(labels)
# positive pairs: same-label combinations from an n-element group, without replacement
self.total_pos_pairs = int(sum(n * (n - 1) / 2 for n in label_counts.values()))
# negative pairs: all remaining cross-label combinations
self.total_neg_pairs = int(len(labels) * (len(labels) - 1) / 2) - self.total_pos_pairs

self.generated_pos_pairs = 0
self.generated_neg_pairs = 0

if num_iterations is not None and num_iterations > 0:
self.len_pos_pairs = num_iterations * len(self.sentences)
self.len_neg_pairs = num_iterations * len(self.sentences)

elif sampling_strategy == "unique":
self.len_pos_pairs = int(np.min([self.total_pos_pairs, self.max_pos_or_neg]))
self.len_neg_pairs = int(np.min([self.total_neg_pairs, self.max_pos_or_neg]))

elif sampling_strategy == "undersampling":
self.len_pos_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))
self.len_neg_pairs = int(np.min([min(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))

elif sampling_strategy == "oversampling":
self.len_pos_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))
self.len_neg_pairs = int(np.min([max(self.total_pos_pairs, self.total_neg_pairs), self.max_pos_or_neg]))

else:
raise ValueError("Invalid sampling strategy. Must be one of 'unique', 'oversampling', or 'undersampling'.")

# the pair generator functions are not ideal, but they still won't blow up memory when training on a big dataset
def generate_positive_pair(self):

pair_generator = shuffle_combinations(self.sentence_labels)
while True:
for (_text, _label), (text, label) in pair_generator:
is_positive = _label == label

if is_positive and self.generated_pos_pairs <= self.len_pos_pairs:
self.generated_pos_pairs += 1
yield {"sentence_1": _text, "sentence_2": text, "label": 1.0}
# combinations exhausted: reshuffle and start over
pair_generator = shuffle_combinations(self.sentence_labels)

def generate_negative_pair(self):
pair_generator = shuffle_combinations(self.sentence_labels)
while True:
for (_text, _label), (text, label) in pair_generator:
is_negative = _label != label

if is_negative and self.generated_neg_pairs <= self.len_neg_pairs:
self.generated_neg_pairs += 1
yield {"sentence_1": _text, "sentence_2": text, "label": 0.0}
# combinations exhausted: reshuffle and start over
pair_generator = shuffle_combinations(self.sentence_labels)

def __iter__(self):
# reset state so the iterator can be recreated and used again if needed
self.generated_pos_pairs = 0
self.generated_neg_pairs = 0

pos_generator = self.generate_positive_pair()
neg_generator = self.generate_negative_pair()

while (self.generated_pos_pairs + self.generated_neg_pairs) < len(self):
if self.generated_pos_pairs < self.len_pos_pairs:
yield next(pos_generator)
if self.generated_neg_pairs < self.len_neg_pairs:
yield next(neg_generator)

def __len__(self) -> int:
return self.len_pos_pairs + self.len_neg_pairs
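# Illustrative usage sketch (assumption, not part of this commit): as an IterableDataset,
# the pairs can be streamed through a torch DataLoader without ever holding them all in memory:
#
#     from torch.utils.data import DataLoader
#
#     dataset = ContrastiveDatasetIt(sentences, labels, sampling_strategy="oversampling")
#     loader = DataLoader(dataset, batch_size=32)
#     for batch in loader:
#         ...  # batch["sentence_1"], batch["sentence_2"], batch["label"]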


class ContrastiveDataset(IterableDataset):
def __init__(
self,