Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions openprompt/plms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,30 @@ def balanced_truncate(input_dict: Dict,
num_tokens_to_truncate: int=0) -> Dict:
'''truncate the inputs with balance, number of cut tokens is proportional to the part's length.
'''
shortenable_lens = [len(parts) if parts[0]==1 else 0
for parts in input_dict['shortenable_ids']]
# Handling empty lists in 'shortenable_ids'
shortenable_lens = [len(parts) if parts and parts[0] == 1 else 0
for parts in input_dict['shortenable_ids']]
total_shortenable_len = sum(shortenable_lens)
num_tokens_to_truncate_each_part = [part_len/total_shortenable_len*num_tokens_to_truncate
for part_len in shortenable_lens]
# Handle empty truncation cases
if total_shortenable_len == 0:
num_tokens_to_truncate_each_part = [0] * len(shortenable_lens)
else:
num_tokens_to_truncate_each_part = [
part_len / total_shortenable_len * num_tokens_to_truncate
for part_len in shortenable_lens
]

round_list(num_tokens_to_truncate_each_part, num_tokens_to_truncate)

truncated_example = defaultdict(list)
for key in input_dict:
parts = input_dict[key]
for num_tokens_to_truncate_part, part in zip(num_tokens_to_truncate_each_part, parts):
truncated_example[key].append(part[:len(part)-num_tokens_to_truncate_part])
truncate_len = max(len(part) - num_tokens_to_truncate_part, 0)
truncated_example[key].append(part[:truncate_len])
# Filtering out empty sequences
for key in truncated_example:
truncated_example[key] = [part for part in truncated_example[key] if part]
return truncated_example

@staticmethod
Expand Down Expand Up @@ -155,6 +167,7 @@ def padding(input_dict: Dict,
max_len: int, pad_id_for_inputs: int=0, pad_id_for_others: int=0) -> None:
for key, value in input_dict.items():
if (len(input_dict[key]) > max_len):
continue
raise ValueError(f'''Truncated seq length of '{key}' still greater than max length {max_len}."\
"One possible reason is that no enough shortenable parts in template. Try adding {{"shortenable": "True"}} property.
''')
Expand Down