/* * testcrypt: a standalone test program that provides direct access to * PuTTY's cryptography and mp_int code. */ /* * This program speaks a line-oriented protocol on standard input and * standard output. It's a half-duplex protocol: it expects to read * one line of command, and then produce a fixed amount of output * (namely a line containing a decimal integer, followed by that many * lines each containing one return value). * * The protocol is human-readable enough to make it debuggable, but * verbose enough that you probably wouldn't want to speak it by hand * at any great length. The Python program test/testcrypt.py wraps it * to give a more useful user-facing API, by invoking this binary as a * subprocess. * * (I decided that was a better idea than making this program an * actual Python module, partly because you can rewrap the same binary * in another scripting language if you prefer, but mostly because * it's easy to attach a debugger to testcrypt or to run it under * sanitisers or valgrind or what have you.) */ #include #include #include #include #include #include "defs.h" #include "ssh.h" #include "misc.h" #include "mpint.h" #include "ecc.h" static NORETURN void fatal_error(const char *p, ...) { va_list ap; fprintf(stderr, "testcrypt: "); va_start(ap, p); vfprintf(stderr, p, ap); va_end(ap); fputc('\n', stderr); exit(1); } void out_of_memory(void) { fatal_error("out of memory"); } static bufchain random_data_queue; int random_byte(void) { unsigned char u; if (bufchain_try_fetch_consume(&random_data_queue, &u, 1)) return u; fatal_error("No random data in queue"); return 0; } #define VALUE_TYPES(X) \ X(string, strbuf *, strbuf_free(v)) \ X(mpint, mp_int *, mp_free(v)) \ X(modsqrt, ModsqrtContext *, modsqrt_free(v)) \ X(monty, MontyContext *, monty_free(v)) \ X(wcurve, WeierstrassCurve *, ecc_weierstrass_curve_free(v)) \ X(wpoint, WeierstrassPoint *, ecc_weierstrass_point_free(v)) \ X(mcurve, MontgomeryCurve *, ecc_montgomery_curve_free(v)) \ X(mpoint, MontgomeryPoint *, ecc_montgomery_point_free(v)) \ X(ecurve, EdwardsCurve *, ecc_edwards_curve_free(v)) \ X(epoint, EdwardsPoint *, ecc_edwards_point_free(v)) \ X(hash, ssh_hash *, ssh_hash_free(v)) \ X(key, ssh_key *, ssh_key_free(v)) \ X(ssh1_cipher, ssh1_cipher *, ssh1_cipher_free(v)) \ X(ssh2_cipher, ssh2_cipher *, ssh2_cipher_free(v)) \ X(mac, ssh2_mac *, ssh2_mac_free(v)) \ X(dh, dh_ctx *, dh_cleanup(v)) \ X(ecdh, ecdh_key *, ssh_ecdhkex_freekey(v)) \ X(rsakex, RSAKey *, ssh_rsakex_freekey(v)) \ X(rsa, RSAKey *, rsa_free(v)) \ /* end of list */ typedef struct Value Value; typedef enum ValueType ValueType; enum ValueType { #define VALTYPE_ENUM(n,t,f) VT_##n, VALUE_TYPES(VALTYPE_ENUM) #undef VALTYPE_ENUM }; const char *const type_names[] = { #define VALTYPE_NAME(n,t,f) #n, VALUE_TYPES(VALTYPE_NAME) #undef VALTYPE_NAME }; struct Value { /* * Protocol identifier assigned to this value when it was created. * Lives in the same malloced block as this Value object itself. */ ptrlen id; /* * Type of the value. */ ValueType type; /* * Union of all the things it could hold. */ union { #define VALTYPE_UNION(n,t,f) t vu_##n; VALUE_TYPES(VALTYPE_UNION) #undef VALTYPE_UNION }; }; static int valuecmp(void *av, void *bv) { Value *a = (Value *)av, *b = (Value *)bv; return ptrlen_strcmp(a->id, b->id); } static int valuefind(void *av, void *bv) { ptrlen *a = (ptrlen *)av; Value *b = (Value *)bv; return ptrlen_strcmp(*a, b->id); } static tree234 *values; static Value *value_new(ValueType vt) { static uint64_t next_index = 0; char *name = dupprintf("%s%"PRIu64, type_names[vt], next_index++); size_t namelen = strlen(name); Value *val = snew_plus(Value, namelen+1); memcpy(snew_plus_get_aux(val), name, namelen+1); val->id.ptr = snew_plus_get_aux(val); val->id.len = namelen; val->type = vt; Value *added = add234(values, val); assert(added == val); sfree(name); return val; } #define VALTYPE_RETURNFN(n,t,f) \ void return_val_##n(strbuf *out, t v) { \ Value *val = value_new(VT_##n); \ val->vu_##n = v; \ put_datapl(out, val->id); \ put_byte(out, '\n'); \ } VALUE_TYPES(VALTYPE_RETURNFN) #undef VALTYPE_RETURNFN static ptrlen get_word(BinarySource *in) { ptrlen toret; toret.ptr = get_ptr(in); toret.len = 0; while (get_avail(in) && get_byte(in) != ' ') toret.len++; return toret; } static const ssh_hashalg *get_hashalg(BinarySource *in) { static const struct { const char *key; const ssh_hashalg *value; } algs[] = { {"md5", &ssh_md5}, {"sha1", &ssh_sha1}, {"sha256", &ssh_sha256}, {"sha384", &ssh_sha384}, {"sha512", &ssh_sha512}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(algs); i++) if (ptrlen_eq_string(name, algs[i].key)) return algs[i].value; fatal_error("hashalg '%.*s': not found", PTRLEN_PRINTF(name)); } static const ssh2_macalg *get_macalg(BinarySource *in) { static const struct { const char *key; const ssh2_macalg *value; } algs[] = { {"hmac_md5", &ssh_hmac_md5}, {"hmac_sha1", &ssh_hmac_sha1}, {"hmac_sha1_buggy", &ssh_hmac_sha1_buggy}, {"hmac_sha1_96", &ssh_hmac_sha1_96}, {"hmac_sha1_96_buggy", &ssh_hmac_sha1_96_buggy}, {"hmac_sha256", &ssh_hmac_sha256}, {"poly1305", &ssh2_poly1305}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(algs); i++) if (ptrlen_eq_string(name, algs[i].key)) return algs[i].value; fatal_error("macalg '%.*s': not found", PTRLEN_PRINTF(name)); } static const ssh_keyalg *get_keyalg(BinarySource *in) { static const struct { const char *key; const ssh_keyalg *value; } algs[] = { {"dsa", &ssh_dss}, {"rsa", &ssh_rsa}, {"ed25519", &ssh_ecdsa_ed25519}, {"p256", &ssh_ecdsa_nistp256}, {"p384", &ssh_ecdsa_nistp384}, {"521", &ssh_ecdsa_nistp521}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(algs); i++) if (ptrlen_eq_string(name, algs[i].key)) return algs[i].value; fatal_error("keyalg '%.*s': not found", PTRLEN_PRINTF(name)); } static const ssh1_cipheralg *get_ssh1_cipheralg(BinarySource *in) { static const struct { const char *key; const ssh1_cipheralg *value; } algs[] = { {"3des", &ssh1_3des}, {"des", &ssh1_des}, {"blowfish", &ssh1_blowfish}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(algs); i++) if (ptrlen_eq_string(name, algs[i].key)) return algs[i].value; fatal_error("ssh1_cipheralg '%.*s': not found", PTRLEN_PRINTF(name)); } static const ssh2_cipheralg *get_ssh2_cipheralg(BinarySource *in) { static const struct { const char *key; const ssh2_cipheralg *value; } algs[] = { {"3des_ctr", &ssh_3des_ssh2_ctr}, {"3des", &ssh_3des_ssh2}, {"des", &ssh_des_ssh2}, {"des_sshcom", &ssh_des_sshcom_ssh2}, {"aes256_ctr", &ssh_aes256_ctr}, {"aes256", &ssh_aes256}, {"aes192_ctr", &ssh_aes192_ctr}, {"aes192", &ssh_aes192}, {"aes128_ctr", &ssh_aes128_ctr}, {"aes128", &ssh_aes128}, {"blowfish", &ssh_blowfish_ssh2_ctr}, {"blowfish", &ssh_blowfish_ssh2}, {"arcfour256", &ssh_arcfour256_ssh2}, {"arcfour128", &ssh_arcfour128_ssh2}, {"chacha20_poly1305", &ssh2_chacha20_poly1305}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(algs); i++) if (ptrlen_eq_string(name, algs[i].key)) return algs[i].value; fatal_error("ssh2_cipheralg '%.*s': not found", PTRLEN_PRINTF(name)); } static const struct ssh_kex *get_dh_group(BinarySource *in) { static const struct { const char *key; const struct ssh_kexes *value; } algs[] = { {"group1", &ssh_diffiehellman_group1}, {"group14", &ssh_diffiehellman_group14}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(algs); i++) if (ptrlen_eq_string(name, algs[i].key)) return algs[i].value->list[0]; fatal_error("dh_group '%.*s': not found", PTRLEN_PRINTF(name)); } static const struct ssh_kex *get_ecdh_alg(BinarySource *in) { static const struct { const char *key; const struct ssh_kex *value; } algs[] = { {"curve25519", &ssh_ec_kex_curve25519}, {"nistp256", &ssh_ec_kex_nistp256}, {"nistp384", &ssh_ec_kex_nistp384}, {"nistp521", &ssh_ec_kex_nistp521}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(algs); i++) if (ptrlen_eq_string(name, algs[i].key)) return algs[i].value; fatal_error("ecdh_alg '%.*s': not found", PTRLEN_PRINTF(name)); } static RsaSsh1Order get_rsaorder(BinarySource *in) { static const struct { const char *key; RsaSsh1Order value; } orders[] = { {"exponent_first", RSA_SSH1_EXPONENT_FIRST}, {"modulus_first", RSA_SSH1_MODULUS_FIRST}, }; ptrlen name = get_word(in); for (size_t i = 0; i < lenof(orders); i++) if (ptrlen_eq_string(name, orders[i].key)) return orders[i].value; fatal_error("rsaorder '%.*s': not found", PTRLEN_PRINTF(name)); } static uintmax_t get_uint(BinarySource *in) { ptrlen word = get_word(in); char *string = mkstr(word); uintmax_t toret = strtoumax(string, NULL, 0); sfree(string); return toret; } static Value *lookup_value(ptrlen word) { Value *val = find234(values, &word, valuefind); if (!val) fatal_error("id '%.*s': not found", PTRLEN_PRINTF(word)); return val; } static Value *get_value(BinarySource *in) { return lookup_value(get_word(in)); } typedef void (*finaliser_fn_t)(strbuf *out, void *ctx); struct finaliser { finaliser_fn_t fn; void *ctx; }; static struct finaliser *finalisers; size_t nfinalisers, finalisersize; static void add_finaliser(finaliser_fn_t fn, void *ctx) { if (nfinalisers >= finalisersize) { finalisersize = nfinalisers * 5 / 4 + 16; finalisers = sresize(finalisers, finalisersize, struct finaliser); } finalisers[nfinalisers].fn = fn; finalisers[nfinalisers].ctx = ctx; nfinalisers++; } static void run_finalisers(strbuf *out) { for (size_t i = 0; i < nfinalisers; i++) finalisers[i].fn(out, finalisers[i].ctx); nfinalisers = 0; } static void finaliser_return_value(strbuf *out, void *ctx) { Value *val = (Value *)ctx; put_datapl(out, val->id); put_byte(out, '\n'); } #define VALTYPE_GETFN(n,t,f) \ static Value *unwrap_value_##n(Value *val) { \ ValueType expected = VT_##n; \ if (expected != val->type) \ fatal_error("id '%.*s': expected %s, got %s", \ PTRLEN_PRINTF(val->id), \ type_names[expected], type_names[val->type]); \ return val; \ } \ static Value *get_value_##n(BinarySource *in) { \ return unwrap_value_##n(get_value(in)); \ } \ static t get_val_##n(BinarySource *in) { \ return get_value_##n(in)->vu_##n; \ } VALUE_TYPES(VALTYPE_GETFN) #undef VALTYPE_GETFN static ptrlen get_val_string_ptrlen(BinarySource *in) { return ptrlen_from_strbuf(get_val_string(in)); } static char *get_val_string_asciz(BinarySource *in) { return get_val_string(in)->s; } static mp_int **get_out_val_mpint(BinarySource *in) { Value *val = value_new(VT_mpint); add_finaliser(finaliser_return_value, val); return &val->vu_mpint; } static void finaliser_return_uint(strbuf *out, void *ctx) { unsigned *uval = (unsigned *)ctx; strbuf_catf(out, "%u\n", *uval); sfree(uval); } static unsigned *get_out_uint(BinarySource *in) { unsigned *uval = snew(unsigned); add_finaliser(finaliser_return_uint, uval); return uval; } static BinarySink *get_out_val_string_binarysink(BinarySource *in) { Value *val = value_new(VT_string); val->vu_string = strbuf_new(); add_finaliser(finaliser_return_value, val); return BinarySink_UPCAST(val->vu_string); } static void finaliser_sfree(strbuf *out, void *ctx) { sfree(ctx); } static BinarySource *get_val_string_binarysource(BinarySource *in) { strbuf *sb = get_val_string(in); BinarySource *src = snew(BinarySource); BinarySource_BARE_INIT(src, sb->u, sb->len); add_finaliser(finaliser_sfree, src); return src; } static ssh_hash *get_consumed_val_hash(BinarySource *in) { Value *val = get_value_hash(in); ssh_hash *toret = val->vu_hash; del234(values, val); sfree(val); return toret; } static void return_int(strbuf *out, intmax_t u) { strbuf_catf(out, "%"PRIdMAX"\n", u); } static void return_uint(strbuf *out, uintmax_t u) { strbuf_catf(out, "0x%"PRIXMAX"\n", u); } static void return_boolean(strbuf *out, bool b) { strbuf_catf(out, "%s\n", b ? "true" : "false"); } static void return_val_string_asciz(strbuf *out, char *s) { strbuf *sb = strbuf_new(); put_data(sb, s, strlen(s)); sfree(s); return_val_string(out, sb); } static void handle_hello(BinarySource *in, strbuf *out) { strbuf_catf(out, "hello, world"); } static void rsa_free(RSAKey *rsa) { freersakey(rsa); sfree(rsa); } static void free_value(Value *val) { switch (val->type) { #define VALTYPE_FREE(n,t,f) case VT_##n: { t v = val->vu_##n; (f); break; } VALUE_TYPES(VALTYPE_FREE) #undef VALTYPE_FREE } sfree(val); } static void handle_free(BinarySource *in, strbuf *out) { Value *val = get_value(in); del234(values, val); free_value(val); } static void handle_newstring(BinarySource *in, strbuf *out) { strbuf *sb = strbuf_new(); while (get_avail(in)) { char c = get_byte(in); if (c == '%') { char hex[3]; hex[0] = get_byte(in); if (hex[0] != '%') { hex[1] = get_byte(in); hex[2] = '\0'; c = strtoul(hex, NULL, 16); } } put_byte(sb, c); } return_val_string(out, sb); } static void handle_getstring(BinarySource *in, strbuf *out) { strbuf *sb = get_val_string(in); for (size_t i = 0; i < sb->len; i++) { char c = sb->s[i]; if (c > ' ' && c < 0x7F && c != '%') { put_byte(out, c); } else { strbuf_catf(out, "%%%02X", 0xFFU & (unsigned)c); } } put_byte(out, '\n'); } static void handle_mp_literal(BinarySource *in, strbuf *out) { ptrlen pl = get_word(in); char *str = mkstr(pl); mp_int *mp = mp__from_string_literal(str); sfree(str); return_val_mpint(out, mp); } static void handle_mp_dump(BinarySource *in, strbuf *out) { mp_int *mp = get_val_mpint(in); for (size_t i = mp_max_bytes(mp); i-- > 0 ;) strbuf_catf(out, "%02X", mp_get_byte(mp, i)); put_byte(out, '\n'); } static void random_queue(ptrlen pl) { bufchain_add(&random_data_queue, pl.ptr, pl.len); } static size_t random_queue_len(void) { return bufchain_size(&random_data_queue); } static void random_clear(void) { bufchain_clear(&random_data_queue); } #define WRAP_monty_identity , mp_int *monty_identity_wrapper(MontyContext *mc) { return mp_copy(monty_identity(mc)); } #define WRAP_monty_modulus , mp_int *monty_modulus_wrapper(MontyContext *mc) { return mp_copy(monty_modulus(mc)); } #define WRAP_ssh_hash_final , strbuf *ssh_hash_final_wrapper(ssh_hash *h) { strbuf *sb = strbuf_new(); void *p = strbuf_append(sb, ssh_hash_alg(h)->hlen); ssh_hash_final(h, p); return sb; } #define WRAP_ssh1_cipher_sesskey , void ssh1_cipher_sesskey_wrapper(ssh1_cipher *c, ptrlen key) { if (key.len != 32) fatal_error("ssh1_cipher_sesskey: needs exactly 32 bytes"); ssh1_cipher_sesskey(c, key.ptr); } #define WRAP_ssh1_cipher_encrypt , strbuf *ssh1_cipher_encrypt_wrapper(ssh1_cipher *c, ptrlen input) { if (input.len % c->vt->blksize) fatal_error("ssh1_cipher_encrypt: needs a multiple of %d bytes", c->vt->blksize); strbuf *sb = strbuf_new(); put_datapl(sb, input); ssh1_cipher_encrypt(c, sb->u, sb->len); return sb; } #define WRAP_ssh1_cipher_decrypt , strbuf *ssh1_cipher_decrypt_wrapper(ssh1_cipher *c, ptrlen input) { if (input.len % c->vt->blksize) fatal_error("ssh1_cipher_decrypt: needs a multiple of %d bytes", c->vt->blksize); strbuf *sb = strbuf_new(); put_datapl(sb, input); ssh1_cipher_decrypt(c, sb->u, sb->len); return sb; } #define WRAP_ssh2_cipher_setiv , void ssh2_cipher_setiv_wrapper(ssh2_cipher *c, ptrlen key) { if (key.len != ssh2_cipher_alg(c)->blksize) fatal_error("ssh2_cipher_setiv: needs exactly %d bytes", ssh2_cipher_alg(c)->blksize); ssh2_cipher_setiv(c, key.ptr); } #define WRAP_ssh2_cipher_setkey , void ssh2_cipher_setkey_wrapper(ssh2_cipher *c, ptrlen key) { if (key.len != ssh2_cipher_alg(c)->padded_keybytes) fatal_error("ssh2_cipher_setkey: needs exactly %d bytes", ssh2_cipher_alg(c)->padded_keybytes); ssh2_cipher_setkey(c, key.ptr); } #define WRAP_ssh2_cipher_encrypt , strbuf *ssh2_cipher_encrypt_wrapper(ssh2_cipher *c, ptrlen input) { if (input.len % ssh2_cipher_alg(c)->blksize) fatal_error("ssh2_cipher_encrypt: needs a multiple of %d bytes", ssh2_cipher_alg(c)->blksize); strbuf *sb = strbuf_new(); put_datapl(sb, input); ssh2_cipher_encrypt(c, sb->u, sb->len); return sb; } #define WRAP_ssh2_cipher_decrypt , strbuf *ssh2_cipher_decrypt_wrapper(ssh2_cipher *c, ptrlen input) { if (input.len % ssh2_cipher_alg(c)->blksize) fatal_error("ssh2_cipher_decrypt: needs a multiple of %d bytes", ssh2_cipher_alg(c)->blksize); strbuf *sb = strbuf_new(); put_datapl(sb, input); ssh2_cipher_decrypt(c, sb->u, sb->len); return sb; } #define WRAP_ssh2_cipher_encrypt_length , strbuf *ssh2_cipher_encrypt_length_wrapper(ssh2_cipher *c, ptrlen input, unsigned long seq) { if (input.len != 4) fatal_error("ssh2_cipher_encrypt_length: needs exactly 4 bytes"); strbuf *sb = strbuf_new(); put_datapl(sb, input); ssh2_cipher_encrypt_length(c, sb->u, sb->len, seq); return sb; } #define WRAP_ssh2_cipher_decrypt_length , strbuf *ssh2_cipher_decrypt_length_wrapper(ssh2_cipher *c, ptrlen input, unsigned long seq) { if (input.len % ssh2_cipher_alg(c)->blksize) fatal_error("ssh2_cipher_decrypt_length: needs exactly 4 bytes"); strbuf *sb = strbuf_new(); put_datapl(sb, input); ssh2_cipher_decrypt_length(c, sb->u, sb->len, seq); return sb; } #define WRAP_ssh2_mac_genresult , strbuf *ssh2_mac_genresult_wrapper(ssh2_mac *m) { strbuf *sb = strbuf_new(); void *u = strbuf_append(sb, ssh2_mac_alg(m)->len); ssh2_mac_genresult(m, u); return sb; } #define WRAP_dh_validate_f , bool dh_validate_f_wrapper(dh_ctx *dh, mp_int *f) { return dh_validate_f(dh, f) == NULL; } void ssh_hash_update(ssh_hash *h, ptrlen pl) { put_datapl(h, pl); } void ssh2_mac_update(ssh2_mac *m, ptrlen pl) { put_datapl(m, pl); } static RSAKey *rsa_new(void) { RSAKey *rsa = snew(RSAKey); memset(rsa, 0, sizeof(RSAKey)); return rsa; } #define WRAP_rsa_ssh1_encrypt , strbuf *rsa_ssh1_encrypt_wrapper(ptrlen input, RSAKey *key) { /* Fold the boolean return value in C into the string return value * for this purpose, by returning the empty string on failure */ strbuf *sb = strbuf_new(); put_datapl(sb, input); if (!rsa_ssh1_encrypt(sb->u, sb->len, key)) sb->len = 0; return sb; } #define WRAP_rsa_ssh1_decrypt_pkcs1 , strbuf *rsa_ssh1_decrypt_pkcs1_wrapper(mp_int *input, RSAKey *key) { /* Again, return "" on failure */ strbuf *sb = strbuf_new(); if (!rsa_ssh1_decrypt_pkcs1(input, key, sb)) sb->len = 0; return sb; } #define return_void(out, expression) (expression) #define VALTYPE_TYPEDEF(n,t,f) \ typedef t TD_val_##n; \ typedef t *TD_out_val_##n; VALUE_TYPES(VALTYPE_TYPEDEF) #undef VALTYPE_TYPEDEF #define OPTIONAL_PTR_FUNC(type) \ typedef TD_val_##type TD_opt_val_##type; \ static TD_opt_val_##type get_opt_val_##type(BinarySource *in) { \ ptrlen word = get_word(in); \ if (ptrlen_eq_string(word, "NULL")) \ return NULL; \ return unwrap_value_##type(lookup_value(word))->vu_##type; \ } OPTIONAL_PTR_FUNC(ssh2_cipher) OPTIONAL_PTR_FUNC(mpint) typedef uintmax_t TD_uint; typedef ptrlen TD_val_string_ptrlen; typedef char *TD_val_string_asciz; typedef BinarySource *TD_val_string_binarysource; typedef unsigned *TD_out_uint; typedef BinarySink *TD_out_val_string_binarysink; typedef ssh_hash *TD_consumed_val_hash; typedef const ssh_hashalg *TD_hashalg; typedef const ssh2_macalg *TD_macalg; typedef const ssh_keyalg *TD_keyalg; typedef const ssh1_cipheralg *TD_ssh1_cipheralg; typedef const ssh2_cipheralg *TD_ssh2_cipheralg; typedef const struct ssh_kex *TD_dh_group; typedef const struct ssh_kex *TD_ecdh_alg; typedef RsaSsh1Order TD_rsaorder; #define WRAPPED_NAME_INNER(a, b, ...) b #define WRAPPED_NAME_OUTER(...) WRAPPED_NAME_INNER(__VA_ARGS__) #define WRAPPED_NAME(func) WRAPPED_NAME_OUTER(WRAP_##func func##_wrapper, func) #define FUNC0(rettype, function) \ static void handle_##function(BinarySource *in, strbuf *out) { \ return_##rettype(out, WRAPPED_NAME(function)()); \ } #define FUNC1(rettype, function, type1) \ static void handle_##function(BinarySource *in, strbuf *out) { \ TD_##type1 arg1 = get_##type1(in); \ return_##rettype(out, WRAPPED_NAME(function)(arg1)); \ } #define FUNC2(rettype, function, type1, type2) \ static void handle_##function(BinarySource *in, strbuf *out) { \ TD_##type1 arg1 = get_##type1(in); \ TD_##type2 arg2 = get_##type2(in); \ return_##rettype(out, WRAPPED_NAME(function)(arg1, arg2)); \ } #define FUNC3(rettype, function, type1, type2, type3) \ static void handle_##function(BinarySource *in, strbuf *out) { \ TD_##type1 arg1 = get_##type1(in); \ TD_##type2 arg2 = get_##type2(in); \ TD_##type3 arg3 = get_##type3(in); \ return_##rettype(out, WRAPPED_NAME(function)(arg1, arg2, arg3)); \ } #define FUNC4(rettype, function, type1, type2, type3, type4) \ static void handle_##function(BinarySource *in, strbuf *out) { \ TD_##type1 arg1 = get_##type1(in); \ TD_##type2 arg2 = get_##type2(in); \ TD_##type3 arg3 = get_##type3(in); \ TD_##type4 arg4 = get_##type4(in); \ return_##rettype(out, WRAPPED_NAME(function)(arg1, arg2, arg3, arg4)); \ } #define FUNC_SELECT_OUTER(...) \ FUNC_SELECT_INNER(__VA_ARGS__,FUNC4,FUNC3,FUNC2,FUNC1,FUNC0)(__VA_ARGS__) #define FUNC_SELECT_INNER(r,f,a1,a2,a3,a4,m,...) m #define FUNC FUNC_SELECT_OUTER #include "testcrypt.h" #undef FUNC static void process_line(BinarySource *in, strbuf *out) { ptrlen id = get_word(in); #define DISPATCH_COMMAND(cmd) \ if (ptrlen_eq_string(id, #cmd)) { \ handle_##cmd(in, out); \ return; \ } DISPATCH_COMMAND(hello); DISPATCH_COMMAND(free); DISPATCH_COMMAND(newstring); DISPATCH_COMMAND(getstring); DISPATCH_COMMAND(mp_literal); DISPATCH_COMMAND(mp_dump); #define FUNC(rettype, function, ...) DISPATCH_COMMAND(function); #include "testcrypt.h" #undef FUNC fatal_error("command '%.*s': unrecognised", PTRLEN_PRINTF(id)); } static void free_all_values(void) { for (Value *val; (val = delpos234(values, 0)) != NULL ;) free_value(val); freetree234(values); } int main(int argc, char **argv) { FILE *fp = stdin; if (argc > 1) { fp = fopen(argv[1], "r"); if (!fp) { fprintf(stderr, "%s: open: %s\n", argv[1], strerror(errno)); return 1; } } values = newtree234(valuecmp); atexit(free_all_values); for (char *line; (line = chomp(fgetline(fp))) != NULL ;) { BinarySource src[1]; BinarySource_BARE_INIT(src, line, strlen(line)); strbuf *sb = strbuf_new(); process_line(src, sb); run_finalisers(sb); size_t lines = 0; for (size_t i = 0; i < sb->len; i++) if (sb->s[i] == '\n') lines++; printf("%zu\n%s", lines, sb->s); fflush(stdout); strbuf_free(sb); sfree(line); } if (fp != stdin) fclose(fp); return 0; }