[RFC] Type-Directed Relay Fuzzing Library

Good questions.

Regarding a unified fuzzing infrastructure, I would need a more specific idea of what the TIR and Relax fuzzers would look like. The fuzzer I’ve described here is specialized to Relay’s type system and the special considerations it leads to. There are some high-level aspects we could easily share across languages, like the formats we use for configuring generators, expressing constraints for oracles, etc. It’s harder to judge what might be similar about program generation beyond the broadest APIs. I would be happy to discuss deduplication of effort with TVMFuzz with @ganler.

You are correct about not being able to reuse the type inference functions, which does result in some additional work for supporting operators in the fuzzer. I can only really imagine one way we could effectively reuse the type inference functions: attempt to sample input types and use the type relation to produce the output type. Such an approach would either be inefficient (if we try to randomly sample types) or require manually specifying constraints on the input types, and hence would require additional work from users anyway (this is what inspired the samplers in the current proposal). While manually registering recognizers, solvers, and samplers for existing operators may be slightly inconvenient (technically you can manage with only a sampler, or with only a recognizer and a solver, but having all of them is more helpful), these did not prove very hard to implement and they allow a great degree of control over the generation policy. It’s also very easy to test the recognizers, solvers, and samplers against the existing type relation (see the unit tests in the POC).
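
To make this concrete, here is a rough sketch of what the three callbacks might look like for a single operator; the names and signatures below are illustrative only, not the POC’s actual API:

# Hypothetical sketch of the per-operator callbacks for something like nn.dense;
# the function names and signatures are illustrative, not the POC's real API.
import random

def sample_dense_types(max_dim=16):
    # Sampler: directly produce a consistent (input types, output type) pair.
    n, k, m = (random.randint(1, max_dim) for _ in range(3))
    return [(n, k), (m, k)], (n, m)      # data, weight -> output

def can_produce_dense(out_shape):
    # Recognizer: could a dense call produce a tensor of this shape? (rank 2)
    return len(out_shape) == 2

def solve_dense_input_types(out_shape, max_dim=16):
    # Solver: given a desired output shape (n, m), pick a reduction axis k
    # and return input shapes the type relation would accept.
    n, m = out_shape
    k = random.randint(1, max_dim)
    return [(n, k), (m, k)]

During generation, the fuzzer can call the sampler to introduce a fresh call, or, when it needs an expression of a known type, use the recognizer to filter candidate operators and the solver to produce matching input types.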

I think it would make the most sense that, when a new op is added, the developer also adds the basic callbacks for it. This would not be a large requirement, and it has an immediate benefit: the developer can use these callbacks right away to start fuzz-testing the new operator. So even though it is more work for the developer, there is also a direct payoff from doing it.

I agree that it would be ideal if there were some way to use only the type relation itself to generate solutions to it, but it is not apparent to me how that could be done efficiently.

1 Like

Thanks @slyubomirsky @yuchenj for cc’ing me! Fuzzing or automated testing is definitely a very promising and affordable direction for improving software quality and safety without spending developers’ precious time writing exhaustive test cases.

I have not completely gone through this pre-RFC, but I will be reviewing it carefully, little by little, and will try to share what I know. To get started, I’d like to share a bit about the related work (as I work specifically in this research area).

The core technique of a language fuzzer is how to generate a random but still well-formed program.

In our case, we generate Relay IR instead of TIR (as TVMFuzz or Tzer do) or other similar C-like languages (as CSmith/GraphicsFuzz/ClangFuzz/etc. do). Therefore, I don’t think there is any “duplication” in Steven’s proposal. Meanwhile, TVMFuzz and Tzer are more like prototypes. Building a reliable and stable fuzzer for a specific system (the best example in my mind is Syzkaller for OS kernels) still requires substantial effort. But of course, once the fuzzer is built, the effort of bug detection is converted from manual time to machine time!

That said, Relay is a graph-level, high-level language, which is different from the prior work targeting C-like low-level languages, and building a Relay fuzzer, or more fundamentally a generator of Relay programs, requires adopting assumptions specific to the language (types, semantics, etc.), which motivates the need for this work.

I think a more closely related and better engineered tool is NNSmith (sorry for the advertisement, but I think it is really relevant), where we generate random DNNs (static graphs, though) and translate them to PyTorch, ONNX, and TensorFlow to test those frameworks and downstream compilers/engines (we found 34 bugs in TVM during our prototyping stage; see ASPLOS'23-NNSmith-Bugs - Google Sheets). However, NNSmith models DNNs as DAGs, which are a subset of the representations Relay is capable of. Therefore, it is still important to perform model generation directly in Relay if we want to find bugs in, say, if statements and while loops. Another project I have noticed is GenCoG, which generates Relay programs; its maintainer has been submitting bug reports to TVM for a while.

(I will give comments regarding the details of the pre-RFC in another reply).

2 Likes

Thank you for the pointers. I had not previously been aware of GenCoG and will have to acquaint myself with it. I’ll study the interfaces of the prior tools and see whether I can improve the fuzzer I’ve implemented here using them (and whether GenCoG’s approach to dealing with type constraints is better).

1 Like

The RFC looks very cool! I like the formulation of type-directed fuzzing, since Relay IR’s representation is a superset of “DAGs”.

I have a few questions or comments:

  1. Are the operator constraints/type relations implemented in a unified and approachable way that the fuzzer can directly extract? In NNSmith, we wanted to avoid being coupled to any specific implementation of a DL system and to keep everything in Python, so we manually implemented such constraints from scratch. More specifically, we modeled “abstract operators” and “abstract tensors” (which are basically the tensor types in the current context), and the rules need to be given by humans (see Listing 2 for an example, or the code).

Overall, I think it would be very beneficial to unify the way we write concrete constraints in type checking and to make the type-constraint information extensible and approachable for the fuzzer.

  2. In addition to the validity constraints, there are some “practical” constraints to “stabilize” the fuzzer. For example, it would be great to have some computation constraints to limit the FLOPs or memory consumption of the generated program; otherwise, it might lead to slowdowns and OOMs.

  3. It is great to see the “sampler” idea, as I was actually building something similar. Meanwhile, in my experience, the overhead of the SMT solver is often negligible compared with TVM’s compilation. For example, the current NNSmith always uses SMT solving yet still generates a 10-node ONNX model within 100 ms, while compilation can take over 10 seconds (though that might be due to a bottleneck in TVM’s ONNX importer). But this idea can be very useful for other DL frameworks with fast/lightweight compilation.

  4. The idea of making it a fuzzing library is amazing. This helps developers reuse the generator to easily obtain inputs with desired patterns when writing tests.

  5. Regarding coverage, I built a tool called “memcov” that compiles TVM (or any other Python/C++ system) with runtime coverage information accessible from Python; it has been used with TVM in my projects for a long time. With this, it is easy to write a coverage-guided fuzzing loop in Python, or to just access uncompressed or compressed CFG coverage (see the loop sketch after this list). The patch for TVM can be extremely simple as well!

  6. Regarding the tradeoff between C++ and Python: I think the bottleneck is often not in generation (at least for most cases in TVM). And even if we only use it to test a few passes instead of the whole pipeline, the bottleneck IMO is in the solver, so using C++ might not make much difference in performance. That said, I would suggest Python to make the tool more approachable if there is no functional limitation.
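
For context, a generic coverage-guided loop along these lines might look like the sketch below; get_hit_count(), generate_program(), and run_program() are placeholders rather than memcov’s actual API:

# Minimal sketch of a coverage-guided generation loop. The three callables are
# hypothetical placeholders; an instrumented TVM build (e.g., via memcov) would
# supply the real coverage accessor.
def fuzz_loop(budget, get_hit_count, generate_program, run_program):
    corpus = []                        # programs that increased coverage
    best = get_hit_count()
    for _ in range(budget):
        prog = generate_program(seed_corpus=corpus)
        run_program(prog)              # exercise the compiler/runtime
        now = get_hit_count()
        if now > best:                 # keep inputs that reached new coverage
            best = now
            corpus.append(prog)
    return corpus, best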

(I still have not fully finished with this thread, but I will share more comments little by little.)

1 Like

@ganler Thanks for inviting me to this discussion. @slyubomirsky Your RFC is very nice, and your efforts in developing a Relay fuzzer are appreciated. As @ganler previously introduced, I have implemented GenCoG, a computation graph generator for Relay. Since you are familiar with NNSmith, I will mainly discuss the similarities and differences between GenCoG and NNSmith here. @ganler, please point out if I get anything wrong about your work.

Similarities:

  1. Both GenCoG and NNSmith generate pure dataflow graphs without control flow.
  2. Both methods require developers to manually specify constraints of operators and leverage a constraint solver to find solutions.
  3. Both methods incrementally construct a computation graph.

Differences (actually there are many, but I may only discuss the major ones):

  1. The most prominent difference is how developers write specifications of operators. GenCoG provides a simple DSL for specifying the constraints in a more organized, concise, and comprehensible way. Let’s take the reshape operator as an example. In GenCoG, we may specify its constraints as follows:
TypeSpec(
  attrs={
    Attr('newshape', List(Var(ran=dl_rank_ran), 
                          lambda _: Var(int, ran=dim_ran, tmpl=True)))
  },
  in_num=1,
  in_ranks=[Var()],
  in_shapes=[List(IN[0].rank, lambda _: Var(tmpl=True))],
  in_dtypes=[Var()],
  extra=[
    Reduce(IN[0].shape, ArithOp.MUL, 1) == Reduce(a('newshape'), ArithOp.MUL, 1)
  ],
  out_num=1,
  out_ranks=[Len(a('newshape'))],
  out_shapes=[a('newshape')],
  out_dtypes=[IN[0].dtype]
)

The solver in GenCoG automatically processes this specification and solves constraints involved. For NNSmith, the specification of reshape is much longer:

@mark_materialize("core")
class Reshape(UnaryOpBase):
    num_var_param = int_range(1, 4)
    in_dtypes = [(i,) for i in DTYPE_ALL]
    out_dtypes = [(i,) for i in DTYPE_ALL]

    def __init__(self, *target_shape):
        super().__init__()
        self.inp_ranks = [int_range(1, 4)]
        self.out_ranks = [(len(target_shape),)]
        self.target_shape: List[Union[int, z3.ExprRef]] = list(target_shape)

    def type_transfer(self, input_shapes: List[AbsTensor]) -> List[AbsTensor]:
        __MAX_SOLVE_SYMBOL__ = 8
        # otherwise OOM.
        ConstraintCheck.le(
            input_shapes[0].ndims + len(self.target_shape), __MAX_SOLVE_SYMBOL__
        )

        if -1 not in self.target_shape:
            return [AbsTensor(self.target_shape, dtype=input_shapes[0].dtype)]
        # else
        abs_tensor = AbsTensor(self.target_shape, dtype=input_shapes[0].dtype)
        auto_dim = -1
        accum = 1
        for i, v in enumerate(self.target_shape):
            # TODO: What to do about bitvectors here?
            if v == -1:
                SanityCheck.eq(auto_dim, -1)
                auto_dim = i
            else:
                accum = nnsmith_mul(accum, v)

        abs_tensor.shape[auto_dim] = nnsmith_div(
            reduce(lambda x, y: nnsmith_mul(x, y), input_shapes[0].shape, 1), accum
        )

        return [abs_tensor]

    def requires(self, input_shapes):
        ret = []

        inp = input_shapes[0]
        src_len, dst_len = inp.ndims, len(self.target_shape)
        if src_len == 0:
            src_len = 1  # special handling for scalar
        if dst_len == 0:
            dst_len = 1  # special handling for scalar
        gres_config = os.getenv("NNSMITH_GRES", "4")
        if gres_config == "5":
            ng = 1
        elif gres_config == "3":
            ng = min(src_len, dst_len)
        elif gres_config == "4":
            ub = min(src_len, dst_len)
            ng = random.choices(
                range(1, ub + 1), k=1, weights=[2**i for i in range(ub)]
            )[0]
        else:
            raise ValueError(f"NNSMITH_GRES={gres_config} is not recognized")
        src_group = random_group(src_len, ng)
        dst_group = random_group(dst_len, ng)
        self.ng = ng
        self.src_group = src_group
        self.dst_group = dst_group
        assert len(src_group) == len(dst_group) == ng, (src_group, dst_group)

        # group constraints
        src_vars = inp.shape
        dst_vars = self.target_shape
        if len(src_vars) == 0:
            src_vars = [1]  # special handling for scalar
        if len(dst_vars) == 0:
            dst_vars = [1]  # special handling for scalar
        cons_group = []
        for gid in range(ng):
            src_idx = src_group[gid]
            dst_idx = dst_group[gid]
            src_prod = reduce(nnsmith_mul, [src_vars[i] for i in src_idx], 1)
            dst_prod = reduce(nnsmith_mul, [dst_vars[i] for i in dst_idx], 1)
            cons_group.append(nnsmith_eq(src_prod, dst_prod))

        ret.extend(cons_group)
        if os.getenv("NNSMITH_CONS_RESHAPE", "off") != "off":
            # should not be too extreme!
            __DIM_LIMIT__ = 4096
            lim = __DIM_LIMIT__
            for s in self.target_shape[::-1]:
                ret.append(nnsmith_le(s, lim))
                lim //= 2
                lim = max(lim, 1)
        assert -1 not in self.target_shape
        return ret

    def deduct_inp_ranks_and_dtype(
        self, out_abs_tensor: List[AbsTensor]
    ) -> List[Tuple[int, DType]]:
        return [(-1, out_abs_tensor[0].dtype)]

Broadly speaking, the specification in GenCoG is more declarative, while the one in NNSmith is more imperative (e.g., some manual sampling code is involved). I am aware that the NNSmith specification contains more fine-grained handling of some corner cases, so it may not be directly comparable with the GenCoG one. However, GenCoG could still potentially make the specification simpler.

  2. GenCoG does not include the value-searching technique from NNSmith, which improves numerical validity. GenCoG could still be extended with this part.

That is what I have to share for now. I may add some more details and discuss other related issues later.

2 Likes

Thanks for the input. Regarding “the specification of reshape is much longer”: just to quickly explain, the validity spec for reshape in NNSmith can be as short as a one-liner like return product(*self.target_shape) == product(*input_shape). It is longer not because we handled more corner cases or because the way we describe specifications is not declarative, but because we did some non-trivial optimization to constrain the solving space and make SMT solving affordable. This is because the complexity of the MIP problem a * b * c == d * f * g is too high. If you read the code more carefully, we did constraint sub-grouping (credit to my collaborator Jinkun) to convert the original high-complexity constraint into a low-complexity one, trading off a bit of solving space for efficiency. I will read the other comments later. Thanks for the input again!

1 Like

Many thanks. I did not completely understand what the code in requires was doing in the first place and made an incorrect judgement in my comment; I apologize for that. Now I understand, and thanks again for your kind explanation.

There is one thing I am considering. Since the specification has to be provided by humans, I think it should be as simple as possible. Maybe the developer only needs to write product(*self.target_shape) == product(*input_shape) in the specification, and the rest of the constraint solving implementation (including optimizations like the constraint sub-grouping you mentioned) can be done by the solver.

Then let me share my own experience. In the implementation of GenCoG, there is no such optimization to improve solving efficiency. The solver directly feeds these non-trivial constraints to the SMT solver. However, the graph generation time of GenCoG is not that long: it takes less than one second to generate a 32-node graph. (That is not significantly slower than NNSmith, right?) In my opinion, though the general complexity of the MIP problem is high, for the cases we encounter in operator generation the actual complexity may still be acceptable.

I am not questioning the quality of your work. You have definitely done an excellent job in DL compiler fuzzing, which is why it is widely recognized by the community. I just want to discuss what may happen if we make different decisions in the design and implementation of a graph-level fuzzer. Please feel free to point out my mistakes; I am looking forward to further discussion.

Thanks for the discussion! I agree that in most cases SMT solving is lightweight, given that the majority of operators have simple constraints (e.g., element-wise or injective). Such optimization should only be considered in very rare cases. Taking the example of reshape again, solving a single reshape operator could take up to 8 seconds depending on the number of symbols and prior constraints, which is why it is sometimes meaningful to do some optimization.


But this could be specific to z3. There are alternative solvers, such as cvc5 and mip, which might solve those constraints faster. BTW, we can talk more about the prior work elsewhere; let’s keep this thread focused on the RFC’s content. :slight_smile:

2 Likes

@ganler and @wzh99, thank you both very much for sharing insights from your explorations in fuzzing TVM. I am glad to hear that, in your experience, SMT solvers have not proven a large burden to use. I preferred to avoid one here so as not to introduce another external dependency into TVM, but perhaps the convenience of a compact specification could be worth it in the long run.

Unfortunately (in response to @ganler’s question above), the type relations used in Relay’s compiler are opaque C++ functions and so cannot easily be exported into SMT queries, so any approach will necessitate manually reimplementing the type constraints, either as solver queries or using the functions I defined here to generate solutions directly. It might be worth considering whether reimplementing Relay’s type relations using a solver would be worthwhile (I wondered about this in my TVMCon talk), but that would be a large project in its own right.
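
To make the tradeoff concrete, here is a minimal sketch (assuming z3, which is exactly the extra dependency I have preferred to avoid) of what a single relation like dense might look like as a solver query; real relations would additionally have to handle broadcasting, dynamic shapes, attributes, and data types:

# Illustrative only: the dense type relation written as a z3 query that solves
# for input shapes given a desired output shape.
from z3 import Ints, Solver, sat

def solve_dense_inputs(out_n, out_m, max_dim=64):
    n, k, m = Ints("n k m")
    s = Solver()
    s.add(n == out_n, m == out_m)      # match the requested output shape (n, m)
    s.add(k >= 1, k <= max_dim)        # the reduction axis is otherwise free
    if s.check() != sat:
        return None
    model = s.model()
    data_shape = (model[n].as_long(), model[k].as_long())
    weight_shape = (model[m].as_long(), model[k].as_long())
    return data_shape, weight_shape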

In addition to the validity constraints, there are some “practical” constraints to “stabilize” the fuzzer. For example, it would be great to have some computation constraints to limit the FLOPs or memory consumption of the generated program; otherwise, it might lead to slowdowns and OOMs.

Implementing such restrictions would be a very good idea, though I am not sure of the best way to do it. I guess there could be counters approximating the cost of operators and some way to short-circuit generation if a candidate proves too expensive. This also reminds me that I need to implement some kind of timeout in the execution phase, because there is nothing to stop the fuzzer from generating an infinite loop…
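
As a starting point, the execution timeout could be a simple subprocess wrapper; a minimal sketch (run_module is a placeholder for whatever actually compiles and runs the generated program) might be:

# Sketch of a timeout guard for the execution phase using multiprocessing;
# run_module is a hypothetical callable, not an existing API.
import multiprocessing

def run_with_timeout(run_module, mod, timeout_s=30):
    proc = multiprocessing.Process(target=run_module, args=(mod,))
    proc.start()
    proc.join(timeout_s)
    if proc.is_alive():                # e.g., the generated program loops forever
        proc.terminate()
        proc.join()
        return "timeout"
    return "exited with code %d" % proc.exitcode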

Regarding coverage, I built a tool called “memcov” that compiles TVM (or any other Python/C++ system) with runtime coverage information accessible from Python; it has been used with TVM in my projects for a long time. With this, it is easy to write a coverage-guided fuzzing loop in Python, or to just access uncompressed or compressed CFG coverage. The patch for TVM can be extremely simple as well!

Regarding the tradeoff between C++ and Python: I think the bottleneck is often not in generation (at least for most cases in TVM). And even if we only use it to test a few passes instead of the whole pipeline, the bottleneck IMO is in the solver, so using C++ might not make much difference in performance. That said, I would suggest Python to make the tool more approachable if there is no functional limitation.

Thank you also for these comments. The subject of tracking our test coverage has come up outside the context of fuzzing, so this might be good to adopt in TVM’s CI as well (@driazati, perhaps you might be interested?). I also appreciate the insight that generation time may not necessarily be a huge impediment; it was certainly easy to assemble this prototype.

The memory side could be approximated by summing the tensor sizes according to their data types (to be more precise, we could do some liveness analysis to estimate the expected peak memory usage). A simpler way is to limit the dimension sizes to a reasonable number. Timeout-based aborting can be implemented via process forking. However, a timeout could in theory itself be a bug, so in the future we may want to add constraints so that generated programs must terminate within a reasonable bound.
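
A rough sketch of that byte-size estimate (the helper and its inputs are illustrative; the shapes and dtypes would be collected from the generated module’s tensor types):

# Sum of tensor sizes weighted by dtype width, as a crude memory bound.
import numpy as np

def estimated_bytes(shapes_and_dtypes):
    total = 0
    for shape, dtype in shapes_and_dtypes:
        total += int(np.prod(shape, dtype=np.int64)) * np.dtype(dtype).itemsize
    return total

# e.g., reject a candidate program if estimated_bytes(...) > 2 * 1024**3 (2 GiB)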

Yeah, coverage can be an important metric for developers to understand which parts of the compiler are not properly exercised, and its usefulness goes beyond fuzzing.

Just want to mention that there are actually some small differences among coverage tools in terms of implementation/use cases. To get readable (i.e., source-code-level) coverage reports for developers, we can use LLVM-COV or GCOV, and they work smoothly with TVM by adding a few compile flags. However, in fuzzing, source code coverage is not used; AFL, libFuzzer, and Tzer (i.e., memcov) all implement CFG-level (LLVM IR level) coverage with LLVM’s SanitizerCoverage, which is more efficient and admits many optimizations (say, hashing the coverage to reduce memory footprint). I also have a blog article related to coverage usage in C++ (if it helps).

2 Likes

We discussed a few aspects of this at the community meeting this morning:

  • When should we generate programs to test? When should the tests be run?
    • @slyubomirsky Other fuzzers take many different approaches: some generate and save programs, others generate them fresh each time. Maybe we should do both in order to test the fuzzer
    • @driazati we should get the fuzzer running in CI at some scale now, we can grow or change it as we try it out
  • As many (potentially thousands) of generated programs hit the same kinds of bugs, how do we group stack traces to simplify error reporting / bug hunting?
    • @slyubomirsky Fuzzing frameworks also differ here, some look at parts of the stack trace
    • @gromero Maybe it’s possible to look at program features to identify what triggers similar bugs
  • How do we save generated programs without relying on the Relay text format?
    • @slyubomirsky pickling the IRModule works out of the box but is pretty opaque
    • @driazati pickle might be ok, we can probably disable the arbitrary code execution parts of it (another note: pickle also has a text format that might be easier to stomach, but there is no parser for it afaik)
    • @sebastianboblestetas We should do the JSON format if the serializer is simple enough
2 Likes

I’ve attempted to implement a JSON serialization format, but I’m facing some difficulties when trying to serialize the Relay Prelude and deserialize it back. I’m not sure what is going on, but the crux of the matter is that preserving reference equality where Relay expects it is fairly tricky, so it would take a fair bit of validation to make sure the serialization is correct. (I’m still not completely sure what is going wrong here, incidentally.) The implementation was also fairly large (~700 lines), though much of that was boilerplate. Fixed, see below edits… I’ll post the serializer when I have tests ready and we can decide whether it’s preferable to pickle.

If we are very intent on having this JSON format rather than pickle, would anyone be interested in discussing how it should be implemented or trying to figure out what issues are arising with the Relay Prelude?

Edit: @driazati, do you think we might be able to cryptographically sign our pickle files somehow to ensure potentially malicious ones can’t be inserted? I’d be willing to explore other options related to pickle, as it’s looking like developing another serializer will be a nontrivial task.

Edit 2: I got a new idea for a simpler approach to JSON serialization (just cache all types and exprs by their pointer value and use that to determine reference equality) and I’ll see if that works instead.

Edit 3: The simpler approach still encounters the same odd problem when serializing the Prelude. I have no clue why this isn’t working, since it tracks reference equality for every part of the AST. Happy to provide code for debugging this, but it’s strange and makes me think the pickle option is the better one.

Edit 4: I think I may have found the issue. It turns out the contents of attrs could be arbitrary TVM objects…
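
For the record, the “cache by pointer value” idea from Edit 2 amounts to serializing the AST as a node table plus indices, so that shared nodes come back as the same object on deserialization. A generic sketch of the idea (not the actual serializer):

# Generic sketch of pointer-cached serialization: each distinct object gets one
# slot in a node table and children refer to slots by index, which preserves
# reference equality across a round trip.
def serialize(root, get_children, encode_node):
    table, index_of = [], {}

    def visit(node):
        key = id(node)                      # "pointer" identity
        if key in index_of:
            return index_of[key]            # shared node: reuse its slot
        idx = len(table)
        index_of[key] = idx
        table.append(None)                  # reserve the slot before recursing
        child_ids = [visit(c) for c in get_children(node)]
        table[idx] = {"data": encode_node(node), "children": child_ids}
        return idx

    return {"root": visit(root), "nodes": table}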

I’ve implemented a JSON serializer if we think that’s preferable to using pickle. The implementation is about 700 lines, much of which is boilerplate, but it’s not very complicated other than having to deal with numpy values (in constants) and attributes nodes. See the implementation here: https://github.com/slyubomirsky/tvm/blob/fuzzer-poc/python/tvm/relay/testing/ast_dict.py

Tests here (I’ve run a lot of fuzzing trials on it without issues, so I’m fairly confident in it): https://github.com/slyubomirsky/tvm/blob/fuzzer-poc/tests/python/relay/fuzzing/test_ast_dict.py

It seems to be a bit slower than pickle, but it could definitely be implemented in C++ and thus be a lot faster.

CC: @driazati and @sebastianboblestetas (since we had discussed it in the community meeting)

1 Like

LGTM on a first glance, maybe a small comment: You often use the construct set(map(lambda …)) like here:

Couldn’t you simply use {gv.name_hint for gv in mod.get_global_vars()} ?

Thank you for taking a look. Yeah I think you’re right that there’s no advantage to using map in such cases.

I suppose it would make sense to present the JSON serializer as one option in the eventual RFC.

Given that it’s implemented and (I’m assuming) fast enough, is there any reason to use pickle/cpickle at this point?

IMO, the main reasons to use pickle over the JSON serialization format, besides speed, are that the TVM community would not be on the hook for maintaining a serializer and that pickle will almost certainly be future-proof with respect to any changes in Relay.

1 Like

Note for future reference: this paper contains a survey of stack trace deduplication methods. The most effective ones rely on some learned components, which we wouldn’t be able to train until we’ve already done some fuzzing, so I will initially try some of the simpler edit distance–based approaches.

I’m pleased to note that I’ve added a simple script for clustering stack traces, allowing for duplicate stack traces to be easily grouped together and thus making it easier to start analyzing the bugs that have emerged. The script is in tests/scripts/relay_fuzzing/cluster_traces.py in the fuzzer proof-of-concept fork. (Note: It requires the Levenshtein library and SK-Learn.)

With this tool, I think it would be fair to say that the “fuzzing pipeline” here is complete: With the code in my proof-of-concept, it is possible to generate programs, run tests for different conditions on them, and analyze the resulting stack traces to start debugging. With these changes, I think it would be useful to start discussing how workable this infrastructure would be and perhaps to start moving on to a formal RFC.

In my small-scale tests, I found it effective to use a very simple form of clustering: agglomerative (hierarchical) clustering based on the Levenshtein distance (i.e., edit distance) between the string dumps of the stack traces. The script simply uses a normalized distance (Levenshtein distance divided by the length of the longer of the two strings) with a 10% cutoff. It’s a very simple method and requires no training.
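
Concretely, the core of the script amounts to something like the following sketch (assuming the python-Levenshtein and scikit-learn packages; the actual script may differ in details):

# Pairwise normalized Levenshtein distances + agglomerative clustering with a
# 10% distance cutoff, mirroring the description above.
import numpy as np
import Levenshtein
from sklearn.cluster import AgglomerativeClustering

def cluster_traces(traces, cutoff=0.1):
    n = len(traces)
    dist = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            d = Levenshtein.distance(traces[i], traces[j])
            dist[i, j] = dist[j, i] = d / max(len(traces[i]), len(traces[j]))
    model = AgglomerativeClustering(
        n_clusters=None,
        metric="precomputed",          # older scikit-learn versions use affinity=
        linkage="average",
        distance_threshold=cutoff,
    )
    return model.fit_predict(dist)     # cluster label per stack trace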

In the future, we can explore using one of the more advanced methods for clustering stack traces if this proves inadequate in practice, but from my cursory examination, the resulting clusters made sense (different kinds of crashes were indeed in different clusters).

Special thanks to @ehsanmok for giving me some advice on this subject.

1 Like