Using decorators for common hardware structures

This is my first non-trivial MyHDL program. I made two decorators to automate the MyHDL evaluation step and reduce code duplication. The first decorator, @chain, connects hardware blocks serially, with one input and one output per block. A simple use for this is a VGA timer. The VGA timer takes a clk and an enable input and outputs the VGA timing signals along with an x and y position, which is used later to decide the color of that pixel. The VGA timer also outputs h_refresh and v_refresh, two one-clk-wide pulses that are sent out when the horizontal and vertical counters, respectively, overflow. First I wrote something like this:

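from myhdl import Signal, intbv, always, always_comb


def vga_timer(clk, en, x, y, h_refresh, v_refresh, h_sync, v_sync, video_on):
    line_width = 800  # differs by resolution (800 clocks per line for 640x480)
    num_lines = 525  # differs by resolution (525 lines for 640x480)
    # h and v count when the screen is displaying as well as during the blanking period
    h = Signal(intbv(0, min=0, max=line_width))
    v = Signal(intbv(0, min=0, max=num_lines))

    @always(clk.posedge)
    def count_hv():
        # the refresh outputs are one clk wide pulses, so clear them by default
        h_refresh.next = 0
        v_refresh.next = 0
        if en:
            # wrap before incrementing so h and v never leave their intbv bounds
            if h == line_width - 1:
                h.next = 0
                h_refresh.next = 1
                if v == num_lines - 1:
                    v.next = 0
                    v_refresh.next = 1
                else:
                    v.next = v + 1
            else:
                h.next = h + 1

    x_width = len(x)
    y_width = len(y)

    @always_comb
    def output_xy():
        x.next = h[x_width:]
        y.next = v[y_width:]

    # logic for h_sync, v_sync, and video_on left out because it's irrelevant to this post
    return count_hv, output_xy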
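Before building hardware it can help to check the counting scheme in plain Python. The sketch below is a hypothetical cycle-by-cycle model of count_hv (not MyHDL code, and `simulate_hv` is a made-up name) with en held high. The refresh pulses are registered, so each is high for exactly one cycle: the cycle right after its counter wraps. The small sizes in the usage are only for illustration.

```python
def simulate_hv(line_width, num_lines, cycles):
    """Cycle-level model of count_hv with en held high.

    Returns one (h, v, h_refresh, v_refresh) tuple per clock cycle, holding the
    register values *after* each rising edge.
    """
    h = v = 0
    h_refresh = v_refresh = 0
    trace = []
    for _ in range(cycles):
        # Compute next-state values from current values, like .next assignments
        next_h_refresh, next_v_refresh = 0, 0
        next_h, next_v = h + 1, v
        if h == line_width - 1:
            next_h, next_h_refresh = 0, 1
            next_v = v + 1
            if v == num_lines - 1:
                next_v, next_v_refresh = 0, 1
        # The clock edge: all registers update together
        h, v = next_h, next_v
        h_refresh, v_refresh = next_h_refresh, next_v_refresh
        trace.append((h, v, h_refresh, v_refresh))
    return trace


trace = simulate_hv(line_width=3, num_lines=2, cycles=12)
print(trace[:6])
```

With the totals for the standard 640x480 mode (800 clocks per line, 525 lines), h_refresh would fire once every 800 cycles and v_refresh once every 800 * 525 = 420000 cycles.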

This works just fine for two counters that chain into each other, but if for some reason five or more counters were needed, there would be a lot of code duplication. For this reason I wrote a counter block that takes a clk, an enable, and a count signal and outputs a pulse when count overflows.

def counter(clk, count, max_value, en=True, pulse=None):
    ...  # body not shown

Now count_hv can be replaced with instances of counter:

h_counter = counter(clk, h, line_width, en=en, pulse=h_refresh)
v_counter = counter(clk, v, num_lines, en=h_refresh, pulse=v_refresh)
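Chaining counters this way turns them into a mixed-radix counter (an odometer): each counter advances only when its predecessor's pulse is high. The pure-Python sketch below models that; `CounterModel` and `run_chain` are hypothetical names, and since the post doesn't show the body of counter, it assumes the pulse is combinational (high while en is high and the count is about to wrap). With a registered pulse, as in count_hv, each later counter would instead advance one cycle after its predecessor wraps.

```python
class CounterModel:
    """Hypothetical cycle-level model of one counter instance.

    pulse is combinational: high while en is high and count is about to wrap.
    """

    def __init__(self, max_value):
        self.max_value = max_value
        self.count = 0

    def pulse(self, en):
        return bool(en) and self.count == self.max_value - 1

    def clock(self, en):
        # Register update on the rising edge
        if en:
            self.count = 0 if self.count == self.max_value - 1 else self.count + 1


def run_chain(max_values, cycles):
    """Clock a chain of counters; counter i is enabled by counter i-1's pulse."""
    counters = [CounterModel(m) for m in max_values]
    for _ in range(cycles):
        # Evaluate the combinational enables first, from current register values
        enables = [True]
        for c in counters[:-1]:
            enables.append(c.pulse(enables[-1]))
        # Then clock every counter with its own enable
        for c, en in zip(counters, enables):
            c.clock(en)
    return [c.count for c in counters]


print(run_chain([3, 2], 5))
```

Under this assumption, run_chain([line_width, num_lines], n) gives the (h, v) position after n clocks.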

This solution also scales to n counters, since each input or output of counter just becomes a list of signals. I have found a list of chained instances to be a very common pattern, and it is annoying to type the same pattern for every chain of small blocks, especially when only the global input and output matter:

in_sigs = [global_in] + [Signal(False) for _ in range(1, num_sigs)]
# block i drives in_sigs[i + 1], which is the next block's input
out_sigs = in_sigs[1:] + [global_out]
small_blocks = [small_block(in_sigs[i], out_sigs[i]) for i in range(num_sigs)]
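The list pattern above only forms a chain if block i's output signal is the same object as block i + 1's input signal. The sketch below checks that wiring with a stand-in `Wire` class (a hypothetical placeholder for MyHDL's Signal, so it runs without MyHDL) and a hypothetical `build_chain` helper.

```python
class Wire:
    """Hypothetical stand-in for a MyHDL Signal; only object identity matters here."""

    def __init__(self, name):
        self.name = name


def build_chain(global_in, global_out, num_blocks, small_block):
    """Instantiate num_blocks blocks in series and return the instances."""
    in_sigs = [global_in] + [Wire(f"link{i}") for i in range(1, num_blocks)]
    # Each block's output is the next block's input; the last block drives global_out
    out_sigs = in_sigs[1:] + [global_out]
    return [small_block(in_sigs[i], out_sigs[i]) for i in range(num_blocks)]


a, z = Wire("in"), Wire("out")
blocks = build_chain(a, z, 4, lambda i, o: (i, o))
print([(i.name, o.name) for i, o in blocks])
```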

Thankfully, Python is great at removing code duplication. My solution to the problem is the decorator @chain. Here’s how @chain is used in the VGA timer:

@chain('en', 'pulse')
def counter(clk, count, max_value, en=True, pulse=None):
    ...  # body unchanged

counter_insts = counter(clk, [h, v], [line_width, num_lines], en=en, pulse=[h_refresh, v_refresh], chain_length=2)

Lone arguments like clk are given to every instance of counter. List arguments like count are distributed, so counter instance i gets count[i]; every list argument must have chain_length elements. The @chain decorator takes two positional inputs, chain_in and chain_out (‘en’ and ‘pulse’ in this case). Inside the decorator, each instance's chain_out is connected to the next instance's chain_in. If the chain_in or chain_out argument isn't a list, or any element of the list is None, signals of type chain_type are created in its place (a lone signal is then connected to the start or end of the chain). chain_type is the final argument to chain; it is a keyword argument that defaults to False, so it wasn’t included in this example.
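The argument-distribution rule can be sketched on its own, separate from any hardware. The hypothetical `distribute` helper below broadcasts lone arguments to every instance, gives instance i element i of every list argument, and rejects lists whose length differs from chain_length. It mirrors the checks in the @chain wrapper, except that it raises ValueError instead of using assert.

```python
def distribute(args, kwargs, chain_length):
    """Split call arguments into one (args, kwargs) pair per chain instance."""

    def pick(value, i):
        if isinstance(value, list):
            if len(value) != chain_length:
                raise ValueError("list arguments must match the chain length")
            return value[i]
        return value  # lone arguments are shared by every instance

    return [
        (tuple(pick(a, i) for a in args), {k: pick(v, i) for k, v in kwargs.items()})
        for i in range(chain_length)
    ]


per_instance = distribute(
    ("clk", ["h", "v"], [800, 525]), {"en": "en", "pulse": ["h_refresh", "v_refresh"]}, 2
)
for instance_args, instance_kwargs in per_instance:
    print(instance_args, instance_kwargs)
```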

In the future I would like to make chain_length implied (unless there are no list arguments) and to allow the chain_out argument (pulse in this case) to be None. I’m also considering allowing custom connection logic so that multiple chain_in’s and chain_out’s can be set. Here’s the full code for @chain, which still needs to be cleaned up and refactored.

from myhdl import Signal, always_comb


def chain(chain_in, chain_out, chain_type=False):
    def chain_decorator(func):
        def chain_wrapper(*args, chain_length=1, **kwargs):
            assert chain_length > 0
            if chain_length == 1:
                return func(*args, **kwargs)

            # chain_in and chain_out can be given either as one signal (the input or output of the entire
            # system) or as a list of signals, one per instance. any missing or None signals are added
            lone_input = True
            lone_output = True
            c_in = kwargs.get(chain_in)
            c_out = kwargs.get(chain_out)
            if isinstance(c_in, list):
                lone_input = False
                chain_in_sigs = kwargs.pop(chain_in)
                assert len(chain_in_sigs) == chain_length, 'list arguments must match the chain length'
                for i in range(chain_length):
                    if chain_in_sigs[i] is None:
                        chain_in_sigs[i] = Signal(chain_type)
            else:
                chain_in_sigs = [Signal(chain_type) for _ in range(chain_length)]

            if isinstance(c_out, list):
                lone_output = False
                chain_out_sigs = kwargs.pop(chain_out)
                assert len(chain_out_sigs) == chain_length, 'list arguments must match the chain length'
                for i in range(chain_length):
                    if chain_out_sigs[i] is None:
                        chain_out_sigs[i] = Signal(chain_type)
            else:
                chain_out_sigs = [Signal(chain_type) for _ in range(chain_length)]

            funcs = []
            for i in range(chain_length):
                # lone arguments go to every instance; list arguments are distributed by index
                arg_sigs = []
                for arg in args:
                    if isinstance(arg, list):
                        assert len(arg) == chain_length, 'list arguments must match the chain length'
                        arg_sigs.append(arg[i])
                    else:
                        arg_sigs.append(arg)
                kwarg_sigs = {}
                for key, arg in kwargs.items():
                    if isinstance(arg, list):
                        assert len(arg) == chain_length, 'list arguments must match the chain length'
                        kwarg_sigs[key] = arg[i]
                    else:
                        kwarg_sigs[key] = arg
                kwarg_sigs[chain_in] = chain_in_sigs[i]
                kwarg_sigs[chain_out] = chain_out_sigs[i]
                funcs.append(func(*arg_sigs, **kwarg_sigs))

            # the output of each instance drives the input of the next one
            @always_comb
            def connect():
                for j in range(1, chain_length):
                    chain_in_sigs[j].next = chain_out_sigs[j - 1]

            # myhdl can't detect when a signal is renamed by being added to a list of signals,
            # so lone chain inputs and outputs are copied through extra combinational processes
            insts = [funcs, connect]
            if lone_input:
                @always_comb
                def drive_chain_in():
                    chain_in_sigs[0].next = c_in
                insts.append(drive_chain_in)
            if lone_output:
                @always_comb
                def drive_chain_out():
                    c_out.next = chain_out_sigs[-1]
                insts.append(drive_chain_out)
            return insts
        return chain_wrapper
    return chain_decorator

The second decorator, @tree, is for blocks that have a list input and a single output. Unlike a chain, a tree structure of hardware block instances is non-trivial to generate without a recursive wrapper. A natural example for @tree is an adder. The adder I wrote can take up to four inputs in terms.

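from myhdl import always_comb


def adder(terms, result):
    assert 1 <= len(terms) <= 4, 'adder takes one to four terms'
    if len(terms) == 1:
        @always_comb
        def add():
            result.next = terms[0]
    elif len(terms) == 2:
        @always_comb
        def add():
            result.next = terms[0] + terms[1]
    elif len(terms) == 3:
        @always_comb
        def add():
            result.next = terms[0] + terms[1] + terms[2]
    else:
        @always_comb
        def add():
            result.next = terms[0] + terms[1] + terms[2] + terms[3]
    return add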

Every hardware block decorated with @tree must take terms as its first positional argument and result as its second (the actual names don’t matter, though). The @tree decorator builds a tree structure in which each leaf is a term and each internal node is an instance of adder. For 16 terms it generates four adders with four terms each and a final adder that sums the results of those four adders. @tree passes any args and kwargs other than terms and result to every instance of adder. Here is the decorated adder:

@tree(4, adder_result_width)
def adder(terms, result):
    ...  # body as above

The first argument to @tree is num_branches, which tells @tree the maximum number of branches per node. The second is get_result_width, a function that computes the minimum width of the result of each adder instance. An adder with inputs of width w needs a result width of w + ceil(log2(num_terms)). Here’s the implementation of adder_result_width:

def adder_result_width(terms):
    from math import log2, ceil

    max_len = 1
    for term in terms:
        max_len = max(max_len, len(term))
    return max_len + ceil(log2(len(terms)))
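A quick sanity check of the width rule: the largest possible sum of n unsigned w-bit terms is n * (2**w - 1), which is less than 2**(w + ceil(log2(n))), so it always fits. The sketch below checks this with plain integers; the hypothetical `result_width` helper takes a list of bit widths instead of signals, an assumption made so it runs without MyHDL.

```python
from math import ceil, log2


def result_width(widths):
    """Same rule as adder_result_width, but on bit widths instead of signals."""
    return max([1] + widths) + ceil(log2(len(widths)))


def worst_case_fits(widths):
    """True if the largest possible sum of the terms fits in result_width bits."""
    largest_sum = sum(2**w - 1 for w in widths)
    return largest_sum.bit_length() <= result_width(widths)


print(result_width([8, 8, 8, 8]), worst_case_fits([8, 8, 8, 8]))
```

The bound is tight for four 8-bit terms: their sum can reach 1020, which needs all 10 bits.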

The full code of @tree:

from myhdl import Signal, intbv


def _balanced_tree_indexer(num_leaves, num_branches):
    """Makes num_branches buckets, fills them as evenly as possible with num_leaves leaves, and returns a list of
    (start, end) indices for slicing the list of leaves."""
    floor = num_leaves//num_branches
    widths = [floor]*num_branches
    for i in range(num_leaves % num_branches):
        widths[i] += 1
    branch_indices = []
    cur_index = 0
    for i in range(num_branches):
        branch_indices.append((cur_index, cur_index + widths[i]))
        cur_index += widths[i]
    return branch_indices

def tree(num_branches, get_result_width):
    """Decorates a hardware block that takes a list of signal inputs, terms, and a signal result. This decorator
    builds a tree structure with at most num_branches branches per node. Any additional args or kwargs are passed
    to every hardware block instance."""
    def tree_decorator(func):
        def wrapper(terms, root_result, *args, **kwargs):
            assert len(terms) > 0
            if len(terms) <= num_branches:
                return func(terms, root_result, *args, **kwargs)

            branch_indices = _balanced_tree_indexer(len(terms), num_branches)
            branches = []
            branch_results = []
            for i in range(num_branches):
                branch_terms = terms[branch_indices[i][0]:branch_indices[i][1]]
                result_width = get_result_width(branch_terms)
                branch_result = Signal(intbv(0)[result_width:])
                branch_results.append(branch_result)
                branches.append(wrapper(branch_terms, branch_result, *args, **kwargs))

            root = wrapper(branch_results, root_result, *args, **kwargs)
            return root, branches

        return wrapper

    return tree_decorator

Suggestions, criticism, and comments welcome.
