|
| 1 | +#include <iostream> |
| 2 | +#include <fstream> |
| 3 | +#include <string> |
| 4 | +#include <cstring> |
| 5 | +#include <vector> |
| 6 | +#include <cstdint> |
| 7 | +#include "../examples/yaml_qp_parser.h" |
| 8 | + |
| 9 | +// Write raw array as .npy format (little-endian, no compression) |
| 10 | +template<typename T> |
| 11 | +void write_npy(const std::string &filename, const T *data, int64_t len) { |
| 12 | + std::ofstream fout(filename, std::ios::binary); |
| 13 | + if (!fout) { |
| 14 | + std::cerr << "Failed to open file: " << filename << "\n"; |
| 15 | + return; |
| 16 | + } |
| 17 | + |
| 18 | + std::string header = |
| 19 | + "{'descr': '" + std::string((sizeof(T) == 8 ? "<f8" : "<i4")) + |
| 20 | + "', 'fortran_order': False, 'shape': (" + std::to_string(len) + ",), }"; |
| 21 | + |
| 22 | + size_t pad_len = 16 - (10 + header.size() + 1) % 16; |
| 23 | + header += std::string(pad_len, ' '); |
| 24 | + header += '\n'; |
| 25 | + |
| 26 | + fout.write("\x93NUMPY", 6); |
| 27 | + fout.put(0x01); fout.put(0x00); |
| 28 | + uint16_t hlen = static_cast<uint16_t>(header.size()); |
| 29 | + fout.write(reinterpret_cast<char *>(&hlen), 2); |
| 30 | + fout.write(header.c_str(), header.size()); |
| 31 | + fout.write(reinterpret_cast<const char *>(data), sizeof(T) * len); |
| 32 | + fout.close(); |
| 33 | +} |
| 34 | + |
| 35 | +void export_csc(const nasoq::CSC *M, const std::string &prefix) { |
| 36 | + int ncol = M->ncol; |
| 37 | + int nnz = M->p[ncol]; |
| 38 | + |
| 39 | + write_npy(prefix + "_p.npy", M->p, ncol + 1); |
| 40 | + write_npy(prefix + "_i.npy", M->i, nnz); |
| 41 | + write_npy(prefix + "_x.npy", M->x, nnz); |
| 42 | +} |
| 43 | + |
| 44 | +int main(int argc, char **argv) { |
| 45 | + if (argc < 2) { |
| 46 | + std::cerr << "Usage: ./export_npy <qp_file.yml>\n"; |
| 47 | + return 1; |
| 48 | + } |
| 49 | + |
| 50 | + std::string fname = argv[1]; |
| 51 | + QPProblem qp; |
| 52 | + if (!parse_qp_yaml(fname, qp)) { |
| 53 | + std::cerr << "Failed to parse the .yml file.\n"; |
| 54 | + return 2; |
| 55 | + } |
| 56 | + |
| 57 | + std::cout << "Exporting parsed QP...\n"; |
| 58 | + |
| 59 | + export_csc(qp.H, "H_csc"); |
| 60 | + write_npy("q.npy", qp.q, qp.n); |
| 61 | + |
| 62 | + if (qp.me > 0 && qp.A && qp.b) { |
| 63 | + export_csc(qp.A, "A_csc"); |
| 64 | + write_npy("b.npy", qp.b, qp.me); |
| 65 | + } |
| 66 | + |
| 67 | + if (qp.mi > 0 && qp.C && qp.l && qp.u) { |
| 68 | + export_csc(qp.C, "C_csc"); |
| 69 | + write_npy("l.npy", qp.l, qp.mi); |
| 70 | + write_npy("u.npy", qp.u, qp.mi); |
| 71 | + } |
| 72 | + |
| 73 | + std::cout << "Output written as .npy files.\n"; |
| 74 | + return 0; |
| 75 | +} |
0 commit comments