Skip to content

Commit d60f6bf

Browse files
committed
v2.0.0
1 parent 3336d77 commit d60f6bf

File tree

13 files changed

+1274
-291
lines changed

13 files changed

+1274
-291
lines changed

app/build.gradle.kts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ android {
1212
minSdk = 28
1313
// minSdk = 31
1414
targetSdk = 36
15-
versionCode = 46
16-
versionName = "1.9.1"
15+
versionCode = 47
16+
versionName = "2.0.0"
1717

1818
testInstrumentationRunner = "androidx.test.runner.AndroidJUnitRunner"
1919
vectorDrawables {
-243 KB
Binary file not shown.
156 KB
Binary file not shown.
144 KB
Binary file not shown.
-165 KB
Binary file not shown.
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
#include <algorithm>
2+
#include <cctype>
3+
#include <filesystem>
4+
#include <fstream>
5+
#include <map>
6+
#include <memory>
7+
#include <sstream>
8+
#include <stack>
9+
#include <stdexcept>
10+
#include <string>
11+
#include <vector>
12+
13+
#include "SafeTensorReader.hpp"
14+
15+
// One flattened prompt fragment produced by PromptProcessor::process().
// NOTE: member order matters — callers construct this via aggregate
// initialization {text, weight, is_embedding, embedding_data}.
struct PromptToken {
  std::string text;                   // original (un-lowercased) token text
  float weight;                       // accumulated emphasis weight
  bool is_embedding;                  // true when text matched a loaded embedding
  std::vector<float> embedding_data;  // embedding tensor data (empty otherwise)
};
21+
22+
class PromptProcessor {
23+
private:
24+
std::map<std::string, std::vector<float>> embeddings_;
25+
std::string embeddings_dir_;
26+
27+
static std::string toLowerCase(const std::string& str) {
28+
std::string result = str;
29+
std::transform(result.begin(), result.end(), result.begin(),
30+
[](unsigned char c) { return std::tolower(c); });
31+
return result;
32+
}
33+
34+
static std::string trim(const std::string& str) {
35+
size_t start = str.find_first_not_of(" \t\r\n");
36+
if (start == std::string::npos) return "";
37+
size_t end = str.find_last_not_of(" \t\r\n");
38+
return str.substr(start, end - start + 1);
39+
}
40+
41+
struct TokenNode {
42+
std::string text;
43+
float weight;
44+
std::vector<TokenNode> children;
45+
bool is_group;
46+
47+
TokenNode() : weight(1.0f), is_group(false) {}
48+
};
49+
50+
TokenNode parsePromptTree(const std::string& prompt) {
51+
TokenNode root;
52+
root.is_group = true;
53+
root.weight = 1.0f;
54+
std::stack<TokenNode*> node_stack;
55+
node_stack.push(&root);
56+
57+
std::string current_text;
58+
size_t i = 0;
59+
60+
while (i < prompt.length()) {
61+
char c = prompt[i];
62+
63+
if (c == ' ' || c == '\t' || c == '\r' || c == '\n') {
64+
if (!current_text.empty() && i + 1 < prompt.length()) {
65+
char next = prompt[i + 1];
66+
if (next != '(' && next != ')' && next != '[' && next != ']' &&
67+
next != ',' && next != ' ' && next != '\t') {
68+
current_text += ' ';
69+
}
70+
}
71+
i++;
72+
continue;
73+
}
74+
75+
if (c == '(') {
76+
if (!current_text.empty()) {
77+
TokenNode text_node;
78+
text_node.text = trim(current_text);
79+
text_node.weight = 1.0f;
80+
text_node.is_group = false;
81+
if (!text_node.text.empty()) {
82+
node_stack.top()->children.push_back(text_node);
83+
}
84+
current_text.clear();
85+
}
86+
87+
TokenNode* parent = node_stack.top();
88+
parent->children.push_back(TokenNode());
89+
TokenNode* new_node = &parent->children.back();
90+
new_node->is_group = true;
91+
new_node->weight = 1.1f;
92+
node_stack.push(new_node);
93+
i++;
94+
95+
} else if (c == ')') {
96+
if (!current_text.empty()) {
97+
size_t colon_pos = current_text.rfind(':');
98+
bool has_weight = false;
99+
100+
if (colon_pos != std::string::npos && node_stack.size() > 1 &&
101+
node_stack.top()->is_group) {
102+
std::string weight_str = trim(current_text.substr(colon_pos + 1));
103+
std::string text_part = trim(current_text.substr(0, colon_pos));
104+
105+
try {
106+
float weight = std::stof(weight_str);
107+
TokenNode text_node;
108+
text_node.text = text_part;
109+
text_node.weight = weight;
110+
text_node.is_group = false;
111+
if (!text_node.text.empty()) {
112+
node_stack.top()->children.push_back(text_node);
113+
}
114+
has_weight = true;
115+
} catch (...) {
116+
// failed to parse weight
117+
}
118+
}
119+
120+
if (!has_weight) {
121+
TokenNode text_node;
122+
text_node.text = trim(current_text);
123+
text_node.weight = 1.0f;
124+
text_node.is_group = false;
125+
if (!text_node.text.empty()) {
126+
node_stack.top()->children.push_back(text_node);
127+
}
128+
}
129+
current_text.clear();
130+
}
131+
132+
if (node_stack.size() > 1) {
133+
node_stack.pop();
134+
}
135+
i++;
136+
137+
} else if (c == '[') {
138+
if (!current_text.empty()) {
139+
TokenNode text_node;
140+
text_node.text = trim(current_text);
141+
text_node.weight = 1.0f;
142+
text_node.is_group = false;
143+
if (!text_node.text.empty()) {
144+
node_stack.top()->children.push_back(text_node);
145+
}
146+
current_text.clear();
147+
}
148+
149+
TokenNode* parent = node_stack.top();
150+
parent->children.push_back(TokenNode());
151+
TokenNode* new_node = &parent->children.back();
152+
new_node->is_group = true;
153+
new_node->weight = 0.9f;
154+
node_stack.push(new_node);
155+
i++;
156+
157+
} else if (c == ']') {
158+
if (!current_text.empty()) {
159+
TokenNode text_node;
160+
text_node.text = trim(current_text);
161+
text_node.weight = 1.0f;
162+
text_node.is_group = false;
163+
if (!text_node.text.empty()) {
164+
node_stack.top()->children.push_back(text_node);
165+
}
166+
current_text.clear();
167+
}
168+
169+
if (node_stack.size() > 1) {
170+
node_stack.pop();
171+
}
172+
i++;
173+
174+
} else if (c == ',') {
175+
if (!current_text.empty()) {
176+
TokenNode text_node;
177+
text_node.text = trim(current_text);
178+
text_node.weight = 1.0f;
179+
text_node.is_group = false;
180+
if (!text_node.text.empty()) {
181+
node_stack.top()->children.push_back(text_node);
182+
}
183+
current_text.clear();
184+
}
185+
TokenNode comma_node;
186+
comma_node.text = ",";
187+
comma_node.weight = 1.0f;
188+
comma_node.is_group = false;
189+
node_stack.top()->children.push_back(comma_node);
190+
i++;
191+
192+
} else {
193+
current_text += c;
194+
i++;
195+
}
196+
}
197+
198+
if (!current_text.empty()) {
199+
TokenNode text_node;
200+
text_node.text = trim(current_text);
201+
text_node.weight = 1.0f;
202+
text_node.is_group = false;
203+
if (!text_node.text.empty()) {
204+
node_stack.top()->children.push_back(text_node);
205+
}
206+
}
207+
208+
return root;
209+
}
210+
211+
void flattenTree(const TokenNode& node, float parent_weight,
212+
std::vector<PromptToken>& tokens) {
213+
float current_weight = parent_weight * node.weight;
214+
215+
if (node.is_group) {
216+
for (const auto& child : node.children) {
217+
flattenTree(child, current_weight, tokens);
218+
}
219+
} else {
220+
if (!node.text.empty()) {
221+
std::string text_lower = toLowerCase(node.text);
222+
223+
if (embeddings_.find(text_lower) != embeddings_.end()) {
224+
tokens.push_back(
225+
{node.text, current_weight, true, embeddings_[text_lower]});
226+
} else {
227+
tokens.push_back({node.text, current_weight, false, {}});
228+
}
229+
}
230+
}
231+
}
232+
233+
public:
234+
PromptProcessor() = default;
235+
236+
void loadEmbeddings(const std::string& embeddings_dir) {
237+
embeddings_dir_ = embeddings_dir;
238+
embeddings_.clear();
239+
240+
if (!std::filesystem::exists(embeddings_dir)) {
241+
return;
242+
}
243+
244+
for (const auto& entry :
245+
std::filesystem::directory_iterator(embeddings_dir)) {
246+
if (entry.path().extension() == ".safetensors") {
247+
try {
248+
SafeTensorReader reader(entry.path().string());
249+
std::string name = entry.path().stem().string();
250+
std::string name_lower = toLowerCase(name);
251+
252+
auto tensor_names = reader.get_tensor_names();
253+
if (!tensor_names.empty()) {
254+
reader.read(tensor_names[0], true);
255+
embeddings_[name_lower] = reader.data;
256+
}
257+
} catch (const std::exception& e) {
258+
// could not load this embedding
259+
}
260+
}
261+
}
262+
}
263+
264+
std::vector<PromptToken> process(const std::string& prompt) {
265+
std::vector<PromptToken> tokens;
266+
267+
TokenNode tree = parsePromptTree(prompt);
268+
269+
flattenTree(tree, 1.0f, tokens);
270+
271+
return tokens;
272+
}
273+
274+
size_t getEmbeddingCount() const { return embeddings_.size(); }
275+
276+
bool hasEmbedding(const std::string& name) const {
277+
return embeddings_.find(toLowerCase(name)) != embeddings_.end();
278+
}
279+
};

app/src/main/cpp/src/SafeTensor2MNN.hpp

Lines changed: 12 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -348,71 +348,32 @@ void generateClipModel(const std::string& dir,
348348
const std::vector<std::string>& loras = {},
349349
const std::vector<float>& lora_weights = {}) {
350350
if (clip_skip_2) {
351-
generateModel(dir, safetensor_file, "clip", clip_skip_2_structure, loras,
351+
generateModel(dir, safetensor_file, "clip_v2", clip_skip_2_structure, loras,
352352
lora_weights);
353353
} else {
354-
generateModel(dir, safetensor_file, "clip", clip_structure, loras,
354+
generateModel(dir, safetensor_file, "clip_v2", clip_structure, loras,
355355
lora_weights);
356356
}
357357

358-
int header_size = 246656;
359-
int middle_size = 2256;
360-
if (clip_skip_2) {
361-
header_size = 167888;
362-
middle_size = 888;
363-
}
364-
365-
auto filename = dir + "/clip.mnn.slimmed";
366-
if (clip_skip_2) {
367-
filename = dir + "/clip_skip_2.mnn.slimmed";
368-
}
369-
370-
std::ifstream slimmed_file(filename, std::ios::binary);
371-
slimmed_file.seekg(0, std::ios::end);
372-
int slimmed_size = slimmed_file.tellg();
373-
slimmed_file.seekg(0, std::ios::beg);
374-
std::vector<uint8_t> slimmed_data(slimmed_size);
375-
slimmed_file.read(reinterpret_cast<char*>(slimmed_data.data()), slimmed_size);
376-
slimmed_file.close();
377-
378358
SafeTensorReader reader(dir + "/" + safetensor_file);
379359

380360
reader.read(
381361
"cond_stage_model.transformer.text_model.embeddings.position_embedding."
382362
"weight",
383-
false);
384-
std::vector<uint8_t> pos_emb_bytes(reader.fp16_data.size() *
385-
sizeof(uint16_t));
386-
std::memcpy(pos_emb_bytes.data(), reader.fp16_data.data(),
387-
pos_emb_bytes.size());
363+
true);
364+
std::ofstream pos_emb_file(dir + "/pos_emb.bin", std::ios::binary);
365+
pos_emb_file.write(reinterpret_cast<const char*>(reader.data.data()),
366+
reader.data.size() * sizeof(float));
367+
pos_emb_file.close();
388368

389369
reader.read(
390370
"cond_stage_model.transformer.text_model.embeddings.token_embedding."
391371
"weight",
392-
false);
393-
std::vector<uint8_t> token_emb_bytes(reader.fp16_data.size() *
394-
sizeof(uint16_t));
395-
std::memcpy(token_emb_bytes.data(), reader.fp16_data.data(),
396-
token_emb_bytes.size());
397-
398-
std::vector<uint8_t> header(slimmed_data.begin(),
399-
slimmed_data.begin() + header_size);
400-
std::vector<uint8_t> middle(slimmed_data.begin() + header_size,
401-
slimmed_data.begin() + header_size + middle_size);
402-
std::vector<uint8_t> tail(slimmed_data.begin() + header_size + middle_size,
403-
slimmed_data.end());
404-
405-
std::ofstream output_file(dir + "/clip.mnn", std::ios::binary);
406-
output_file.write(reinterpret_cast<const char*>(header.data()),
407-
header.size());
408-
output_file.write(reinterpret_cast<const char*>(pos_emb_bytes.data()),
409-
pos_emb_bytes.size());
410-
output_file.write(reinterpret_cast<const char*>(middle.data()),
411-
middle.size());
412-
output_file.write(reinterpret_cast<const char*>(token_emb_bytes.data()),
413-
token_emb_bytes.size());
414-
output_file.write(reinterpret_cast<const char*>(tail.data()), tail.size());
415-
output_file.close();
372+
true);
373+
std::ofstream token_emb_file(dir + "/token_emb.bin", std::ios::binary);
374+
token_emb_file.write(reinterpret_cast<const char*>(reader.data.data()),
375+
reader.data.size() * sizeof(float));
376+
token_emb_file.close();
416377
}
417378

418379
void generateMNNModels(const std::string& dir,

app/src/main/cpp/src/SafeTensorReader.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
#ifndef SAFE_TENSOR_READER_HPP
2+
#define SAFE_TENSOR_READER_HPP
3+
14
#include <cstring>
25
#include <fstream>
36
#include <map>
@@ -205,4 +208,6 @@ class SafeTensorReader {
205208
}
206209

207210
int get_tensor_count() const { return tensor_map_.size(); }
208-
};
211+
};
212+
213+
#endif // SAFE_TENSOR_READER_HPP

0 commit comments

Comments (0)