@@ -168,22 +168,58 @@ llvm::Value *get_block_header(
168
168
llvm::Type::getInt64Ty (module ->getContext ()), header_val));
169
169
}
170
170
171
+ template <typename T>
172
+ requires std::same_as<T, llvm::BasicBlock>
173
+ || std::same_as<T, llvm::Instruction>
171
174
llvm::Value *allocate_term (
172
- llvm::Type *alloc_type, llvm::BasicBlock *block, char const *alloc_fn) {
173
- return allocate_term (
174
- alloc_type, llvm::ConstantExpr::getSizeOf (alloc_type), block, alloc_fn);
175
- }
176
-
177
- llvm::Value *allocate_term (
178
- llvm::Type *alloc_type, llvm::Value *len, llvm::BasicBlock *block,
175
+ llvm::Type *alloc_type, llvm::Value *len, T *insert_point,
179
176
char const *alloc_fn) {
180
177
auto *malloc = create_malloc (
181
- block , len, kore_heap_alloc (alloc_fn, block ->getModule ()));
178
+ insert_point , len, kore_heap_alloc (alloc_fn, insert_point ->getModule ()));
182
179
183
180
set_debug_loc (malloc);
184
181
return malloc;
185
182
}
186
183
184
// Returns true when `alloc_fn` names one of the plain allocation entry
// points (`kore_alloc`, `kore_alloc_old`, `kore_alloc_always_gc`).
//
// The comparison is exact: the names must match character for character,
// with no surrounding whitespace.
static bool is_basic_alloc(std::string const &alloc_fn) {
  return alloc_fn == "kore_alloc" || alloc_fn == "kore_alloc_old"
         || alloc_fn == "kore_alloc_always_gc";
}
188
+
189
+ llvm::Value *allocate_term (
190
+ llvm::Type *alloc_type, llvm::BasicBlock *block, char const *alloc_fn,
191
+ bool mergeable) {
192
+ llvm::DataLayout layout (block->getModule ());
193
+ auto type_size = layout.getTypeAllocSize (alloc_type).getFixedValue ();
194
+ auto *ty = llvm::Type::getInt64Ty (block->getContext ());
195
+ if (mergeable) {
196
+ if (auto *first = block->getFirstNonPHI ()) {
197
+ if (auto *call = llvm::dyn_cast<llvm::CallInst>(first)) {
198
+ if (auto *func = call->getCalledFunction ()) {
199
+ if (auto *size
200
+ = llvm::dyn_cast<llvm::ConstantInt>(call->getOperand (0 ))) {
201
+ if (func->getName () == alloc_fn && is_basic_alloc (alloc_fn)
202
+ && size->getLimitedValue () + type_size < max_block_merge_size) {
203
+ call->setOperand (
204
+ 0 , llvm::ConstantExpr::getAdd (
205
+ size, llvm::ConstantInt::get (ty, type_size)));
206
+ auto *ret = llvm::GetElementPtrInst::Create (
207
+ llvm::Type::getInt8Ty (block->getContext ()), call, {size},
208
+ " alloc_chunk" , block);
209
+ set_debug_loc (ret);
210
+ return ret;
211
+ }
212
+ }
213
+ }
214
+ }
215
+ }
216
+ return allocate_term (
217
+ alloc_type, llvm::ConstantInt::get (ty, type_size), block, alloc_fn);
218
+ }
219
+ return allocate_term (
220
+ alloc_type, llvm::ConstantInt::get (ty, type_size), block, alloc_fn);
221
+ }
222
+
187
223
value_type term_type (
188
224
kore_pattern *pattern, llvm::StringMap<value_type> &substitution,
189
225
kore_definition *definition) {
@@ -686,7 +722,8 @@ llvm::Value *create_term::create_function_call(
686
722
// we don't use alloca here because the tail call optimization pass for llvm
687
723
// doesn't handle correctly functions with alloca
688
724
alloc_sret = allocate_term (
689
- return_type, current_block_, get_collection_alloc_fn (return_cat.cat ));
725
+ return_type, current_block_, get_collection_alloc_fn (return_cat.cat ),
726
+ true );
690
727
sret_type = return_type;
691
728
real_args.insert (real_args.begin (), alloc_sret);
692
729
types.insert (types.begin (), alloc_sret->getType ());
@@ -759,7 +796,8 @@ llvm::Value *create_term::not_injection_case(
759
796
children.push_back (child_value);
760
797
idx++;
761
798
}
762
- llvm::Value *block = allocate_term (block_type, current_block_);
799
+ llvm::Value *block
800
+ = allocate_term (block_type, current_block_, " kore_alloc" , true );
763
801
llvm::Value *block_header_ptr = llvm::GetElementPtrInst::CreateInBounds (
764
802
block_type, block,
765
803
{llvm::ConstantInt::get (llvm::Type::getInt64Ty (ctx_), 0 ),
@@ -1162,7 +1200,7 @@ std::string make_apply_rule_function(
1162
1200
if (!arg->getType ()->isPointerTy ()) {
1163
1201
auto *ptr = allocate_term (
1164
1202
arg->getType (), creator.get_current_block (),
1165
- get_collection_alloc_fn (cat.cat ));
1203
+ get_collection_alloc_fn (cat.cat ), true );
1166
1204
new llvm::StoreInst (arg, ptr, creator.get_current_block ());
1167
1205
arg = ptr;
1168
1206
}
0 commit comments