# Copyright cocotb contributors
# Copyright (c) 2014 Potential Ventures Ltd
# Licensed under the Revised BSD License, see LICENSE for details.
# SPDX-License-Identifier: BSD-3-Clause
"""Drivers for Advanced Microcontroller Bus Architecture."""
import array
import collections.abc
import enum
import itertools
from typing import Any, List, Optional, Sequence, Tuple, Union
import cocotb
from cocotb.binary import BinaryValue
from cocotb.handle import SimHandleBase
from cocotb.triggers import ClockCycles, Combine, Lock, ReadOnly, RisingEdge
from cocotb_bus.drivers import BusDriver
class AXIBurst(enum.IntEnum):
FIXED = 0b00
INCR = 0b01
WRAP = 0b10
class AXIxRESP(enum.IntEnum):
OKAY = 0b00
EXOKAY = 0b01
SLVERR = 0b10
DECERR = 0b11
class AXIProtocolError(Exception):
def __init__(self, message: str, xresp: AXIxRESP):
super().__init__(message)
self.xresp = xresp
class AXIReadBurstLengthMismatch(Exception):
pass
[docs]class AXI4Master(BusDriver):
"""AXI4 Master
TODO: Kill all pending transactions if reset is asserted.
"""
_signals = [
"AWVALID", "AWADDR", "AWREADY", "AWID", "AWLEN", "AWSIZE", "AWBURST",
"WVALID", "WREADY", "WDATA", "WSTRB",
"BVALID", "BREADY", "BRESP", "BID",
"ARVALID", "ARADDR", "ARREADY", "ARID", "ARLEN", "ARSIZE", "ARBURST",
"RVALID", "RREADY", "RRESP", "RDATA", "RID", "RLAST"]
_optional_signals = ["AWREGION", "AWLOCK", "AWCACHE", "AWPROT", "AWQOS",
"WLAST",
"ARREGION", "ARLOCK", "ARCACHE", "ARPROT", "ARQOS"]
def __init__(self, entity: SimHandleBase, name: str, clock: SimHandleBase,
**kwargs: Any):
BusDriver.__init__(self, entity, name, clock, **kwargs)
# Drive some sensible defaults (setimmediatevalue to avoid x asserts)
self.bus.AWVALID.setimmediatevalue(0)
self.bus.WVALID.setimmediatevalue(0)
self.bus.ARVALID.setimmediatevalue(0)
self.bus.BREADY.setimmediatevalue(1)
self.bus.RREADY.setimmediatevalue(1)
# Set the default value (0) for the unsupported signals, which
# translate to:
# * Transaction IDs to 0
# * Region identifier to 0
# * Normal (non-exclusive) access
# * Device non-bufferable access
# * Unprivileged, secure data access
# * No QoS
unsupported_signals = [
"AWID", "AWREGION", "AWLOCK", "AWCACHE", "AWPROT", "AWQOS",
"ARID", "ARREGION", "ARLOCK", "ARCACHE", "ARPROT", "ARQOS"]
for signal in unsupported_signals:
try:
getattr(self.bus, signal).setimmediatevalue(0)
except AttributeError:
pass
# Mutex for each channel to prevent contention
self.write_address_busy = Lock(name + "_awbusy")
self.read_address_busy = Lock(name + "_arbusy")
self.write_data_busy = Lock(name + "_wbusy")
self.read_data_busy = Lock(name + "_rbusy")
self.write_response_busy = Lock(name + "_bbusy")
@staticmethod
def _check_length(length: int, burst: AXIBurst) -> None:
"""Check that the provided burst length is valid."""
if length <= 0:
raise ValueError("Burst length must be a positive integer")
if burst is AXIBurst.INCR:
if length > 256:
raise ValueError("Maximum burst length for INCR bursts is 256")
elif burst is AXIBurst.WRAP:
if length not in (1, 2, 4, 8, 16):
raise ValueError("Burst length for WRAP bursts must be 1, 2, "
"4, 8 or 16")
else:
if length > 16:
raise ValueError("Maximum burst length for FIXED bursts is 16")
@staticmethod
def _check_size(size: int, data_bus_width: int) -> None:
"""Check that the provided transfer size is valid."""
if size > data_bus_width:
raise ValueError("Beat size ({} B) is larger than the bus width "
"({} B)".format(size, data_bus_width))
elif size <= 0 or size & (size - 1) != 0:
raise ValueError("Beat size must be a positive power of 2")
@staticmethod
def _check_4kB_boundary_crossing(address: int, burst: AXIBurst, size: int,
length: int) -> None:
"""Check that the provided burst does not cross a 4kB boundary."""
if burst is AXIBurst.INCR:
last_address = address + size * (length - 1)
if address & ~0xfff != last_address & ~0xfff:
raise ValueError(
"INCR burst with start address {:#x} and last address "
"{:#x} crosses the 4kB boundary {:#x}, which is forbidden "
"in a single burst"
.format(address, last_address,
(address & ~0xfff) + 0x1000))
async def _send_write_address(
self, address: int, length: int, burst: AXIBurst, size: int,
delay: int, sync: bool
) -> None:
"""Send the write address, with optional delay (in clocks)"""
async with self.write_address_busy:
if sync:
await RisingEdge(self.clock)
await ClockCycles(self.clock, delay)
# Set the address and, if present on the bus, burst, length and
# size
self.bus.AWADDR <= address
self.bus.AWVALID <= 1
if hasattr(self.bus, "AWBURST"):
self.bus.AWBURST <= burst.value
if hasattr(self.bus, "AWLEN"):
self.bus.AWLEN <= length - 1
if hasattr(self.bus, "AWSIZE"):
self.bus.AWSIZE <= size.bit_length() - 1
# Wait until acknowledged
while True:
await ReadOnly()
if self.bus.AWREADY.value:
break
await RisingEdge(self.clock)
await RisingEdge(self.clock)
self.bus.AWVALID <= 0
async def _send_write_data(
self, address, data: Sequence[int], burst: AXIBurst, size: int,
delay: int, byte_enable: Sequence[Optional[int]], sync: bool
) -> None:
"""Send the write data, with optional delay (in clocks)."""
# Helper function for narrow bursts
def mask_and_shift(value: int, block_size: int, block_num: int) -> int:
return (value & (2**block_size - 1)) << (block_num * block_size)
# [0x33221100, 0x77665544] --> [0x221100XX, 0x66554433]
def unalign_data(
data: Sequence[int], size_bits: int, shift: int
) -> List[int]:
padded_data = (0,) + tuple(value for value in data)
low_mask = 2**(size_bits - shift) - 1
high_mask = (2**shift - 1) << (size_bits - shift)
return [(padded_data[i] & high_mask) >> (size_bits - shift) |
(padded_data[i + 1] & low_mask) << shift
for i in range(len(data))]
strobes = []
byte_enable_iterator = iter(byte_enable)
try:
for i in range(len(data)):
current_byte_enable = next(byte_enable_iterator)
strobes.append(2**size - 1 if current_byte_enable is None
else current_byte_enable)
except StopIteration:
# Fill the remaining strobes with the last one if we have reached
# the end of the iterator
strobes += [strobes[-1]] * (len(data) - i)
# Unalign the words and strobes (if unaligned and not FIXED)
if address % size != 0:
shift = (address % size) * 8
if burst is AXIBurst.FIXED:
data = [(word << shift) & (2**(size * 8) - 1) for word in data]
strobes = \
[(strb << shift // 8) & (2**size - 1) for strb in strobes]
else:
data = unalign_data(data, size * 8, shift)
strobes = unalign_data(strobes, size, address % size)
async with self.write_data_busy:
if sync:
await RisingEdge(self.clock)
wdata_bytes = len(self.bus.WDATA) // 8
narrow_block = (address % wdata_bytes) // size
for beat_num, (word, strobe) in enumerate(zip(data, strobes)):
await ClockCycles(self.clock, delay)
self.bus.WVALID <= 1
self.bus.WDATA <= mask_and_shift(word, size * 8, narrow_block)
self.bus.WSTRB <= mask_and_shift(strobe, size, narrow_block)
if burst is not AXIBurst.FIXED:
narrow_block = (narrow_block + 1) % (wdata_bytes // size)
if hasattr(self.bus, "WLAST"):
if beat_num == len(data) - 1:
self.bus.WLAST <= 1
else:
self.bus.WLAST <= 0
while True:
await RisingEdge(self.clock)
if self.bus.WREADY.value:
break
if beat_num == len(data) - 1:
self.bus.WVALID <= 0
[docs] @cocotb.coroutine
async def write(
self, address: int, value: Union[int, Sequence[int]], *,
size: Optional[int] = None, burst: AXIBurst = AXIBurst.INCR,
byte_enable: Union[Optional[int], Sequence[Optional[int]]] = None,
address_latency: int = 0, data_latency: int = 0, sync: bool = True
) -> None:
"""Write a value to an address.
With unaligned writes (when ``address`` is not a multiple of ``size``),
only the low ``size - address % size`` Bytes are written for:
* the last element of ``value`` for INCR or WRAP bursts, or
* every element of ``value`` for FIXED bursts.
Args:
address: The address to write to.
value: The data value(s) to write.
size: The size (in bytes) of each beat. Defaults to None (width of
the data bus).
burst: The burst type, either ``FIXED``, ``INCR`` or ``WRAP``.
Defaults to ``INCR``.
byte_enable: Which bytes in value to actually write. Defaults to
None (write all bytes).
address_latency: Delay before setting the address (in clock
cycles). Default is no delay.
data_latency: Delay before setting the data value (in clock
cycles).
Default is no delay.
sync: Wait for rising edge on clock initially.
Defaults to True.
Raises:
ValueError: If any of the input parameters is invalid.
AXIProtocolError: If write response from AXI is not ``OKAY``.
"""
if not isinstance(value, collections.abc.Sequence):
value = (value,) # If value is not a sequence, make it
if not isinstance(byte_enable, collections.abc.Sequence):
byte_enable = (byte_enable,) # Same for byte_enable
if size is None:
size = len(self.bus.WDATA) // 8
else:
AXI4Master._check_size(size, len(self.bus.WDATA) // 8)
AXI4Master._check_length(len(value), burst)
AXI4Master._check_4kB_boundary_crossing(address, burst, size,
len(value))
write_address = self._send_write_address(address, len(value), burst,
size, address_latency, sync)
write_data = self._send_write_data(address, value, burst, size,
data_latency, byte_enable, sync)
await Combine(cocotb.fork(write_address), cocotb.fork(write_data))
async with self.write_response_busy:
# Wait for the response
while True:
await ReadOnly()
if self.bus.BVALID.value and self.bus.BREADY.value:
result = AXIxRESP(self.bus.BRESP.value.integer)
break
await RisingEdge(self.clock)
await RisingEdge(self.clock)
if result is not AXIxRESP.OKAY:
err_msg = "Write to address {0:#x}"
if len(value) != 1:
err_msg += " ({1} beats, {2} burst)"
err_msg += " failed with BRESP: {3} ({4})"
raise AXIProtocolError(
err_msg.format(address, len(value), burst.name,
result.value, result.name), result)
[docs] @cocotb.coroutine
async def read(
self, address: int, length: int = 1, *,
size: Optional[int] = None, burst: AXIBurst = AXIBurst.INCR,
return_rresp: bool = False, sync: bool = True
) -> Union[List[BinaryValue], List[Tuple[BinaryValue, AXIxRESP]]]:
"""Read from an address.
With unaligned reads (when ``address`` is not a multiple of ``size``)
with INCR or WRAP bursts, the last element of the returned read data
will be only the low-order ``size - address % size`` Bytes of the last
word.
With unaligned reads with FIXED bursts, every element of the returned
read data will be only the low-order ``size - address % size`` Bytes.
Args:
address: The address to read from.
length: Number of words to transfer. Defaults to 1.
size: The size (in bytes) of each beat. Defaults to None (width of
the data bus).
burst: The burst type, either ``FIXED``, ``INCR`` or ``WRAP``.
Defaults to ``INCR``.
return_rresp: Return the list of RRESP values, instead of raising
an AXIProtocolError in case of not OKAY. Defaults to False.
sync: Wait for rising edge on clock initially. Defaults to True.
Returns:
The read data values or, if *return_rresp* is True, a list of pairs
each containing the data and RRESP values.-
Raises:
ValueError: If any of the input parameters is invalid.
AXIProtocolError: If read response from AXI is not ``OKAY`` and
*return_rresp* is False
AXIReadBurstLengthMismatch: If the received number of words does
not match the requested one.
"""
# Helper function for narrow bursts
def shift_and_mask(binvalue: BinaryValue, bytes_num: int,
byte_shift: int) -> BinaryValue:
start = byte_shift * 8
end = (bytes_num + byte_shift) * 8
return binvalue[len(binvalue) - end:len(binvalue) - start - 1]
# [0x221100XX, 0x66554433] --> [0x33221100, 0x665544]
def realign_data(
data: Sequence[BinaryValue], size_bits: int, shift: int
) -> List[BinaryValue]:
binstr_join = "".join([word.binstr[::-1] for word in data])
binstr_join = binstr_join[shift:]
data_binstr = [binstr_join[i * size_bits:(i + 1) * size_bits][::-1]
for i in range(len(data))]
return [BinaryValue(value=binstr, n_bits=len(binstr))
for binstr in data_binstr]
if size is None:
size = len(self.bus.RDATA) // 8
else:
AXI4Master._check_size(size, len(self.bus.RDATA) // 8)
AXI4Master._check_length(length, burst)
AXI4Master._check_4kB_boundary_crossing(address, burst, size, length)
rdata_bytes = len(self.bus.RDATA) // 8
byte_offset = (address % rdata_bytes) // size * size
async with self.read_address_busy:
if sync:
await RisingEdge(self.clock)
self.bus.ARADDR <= address
self.bus.ARVALID <= 1
if hasattr(self.bus, "ARLEN"):
self.bus.ARLEN <= length - 1
if hasattr(self.bus, "ARSIZE"):
self.bus.ARSIZE <= size.bit_length() - 1
if hasattr(self.bus, "ARBURST"):
self.bus.ARBURST <= burst.value
while True:
await ReadOnly()
if self.bus.ARREADY.value:
break
await RisingEdge(self.clock)
await RisingEdge(self.clock)
self.bus.ARVALID <= 0
async with self.read_data_busy:
data = []
rresp = []
for beat_num in itertools.count():
while True:
await ReadOnly()
if self.bus.RVALID.value and self.bus.RREADY.value:
# Shift and mask to correctly handle narrow bursts
beat_value = shift_and_mask(self.bus.RDATA.value,
size, byte_offset)
data.append(beat_value)
rresp.append(AXIxRESP(self.bus.RRESP.value.integer))
if burst is not AXIBurst.FIXED:
byte_offset = (byte_offset + size) % rdata_bytes
break
await RisingEdge(self.clock)
if not hasattr(self.bus, "RLAST") or self.bus.RLAST.value:
break
await RisingEdge(self.clock)
await RisingEdge(self.clock)
if len(data) != length:
raise AXIReadBurstLengthMismatch(
"AXI4 slave returned {} data than expected (requested {} "
"words, received {})"
.format("more" if len(data) > length else "less",
length, len(data)))
# Re-align the words
if address % size != 0:
shift = (address % size) * 8
if burst is AXIBurst.FIXED:
data = [word[0:size * 8 - shift - 1] for word in data]
else:
data = realign_data(data, size * 8, shift)
if return_rresp:
return list(zip(data, rresp))
else:
for beat_number, beat_result in enumerate(rresp):
if beat_result is not AXIxRESP.OKAY:
err_msg = "Read on address {0:#x}"
if length != 1:
err_msg += " (beat {1} of {2}, {3} burst)"
err_msg += " failed with RRESP: {4} ({5})"
err_msg = err_msg.format(
address, beat_number + 1, length, burst,
beat_result.value, beat_result.name)
raise AXIProtocolError(err_msg, beat_result)
return data
def __len__(self):
return 2**len(self.bus.ARADDR)
[docs]class AXI4LiteMaster(AXI4Master):
"""AXI4-Lite Master"""
_signals = ["AWVALID", "AWADDR", "AWREADY", # Write address channel
"WVALID", "WREADY", "WDATA", "WSTRB", # Write data channel
"BVALID", "BREADY", "BRESP", # Write response channel
"ARVALID", "ARADDR", "ARREADY", # Read address channel
"RVALID", "RREADY", "RRESP", "RDATA"] # Read data channel
_optional_signals = []
[docs] @cocotb.coroutine
async def write(
self, address: int, value: int, byte_enable: Optional[int] = None,
address_latency: int = 0, data_latency: int = 0, sync: bool = True
) -> BinaryValue:
"""Write a value to an address.
Args:
address: The address to write to.
value: The data value to write.
byte_enable: Which bytes in value to actually write. Defaults to
None (write all bytes).
address_latency: Delay before setting the address (in clock
cycles). Default is no delay.
data_latency: Delay before setting the data value (in clock
cycles). Default is no delay.
sync: Wait for rising edge on clock initially. Defaults to True.
Returns:
The write response value.
Raises:
ValueError: If any of the input parameters is invalid.
AXIProtocolError: If write response from AXI is not ``OKAY``.
"""
if isinstance(value, collections.abc.Sequence):
raise ValueError("AXI4-Lite does not support burst transfers")
await super().write(
address=address, value=value, size=None, burst=AXIBurst.INCR,
byte_enable=byte_enable, address_latency=address_latency,
data_latency=data_latency, sync=sync)
# Needed for backwards compatibility
return BinaryValue(value=AXIxRESP.OKAY.value, n_bits=2)
[docs] @cocotb.coroutine
async def read(self, address: int, sync: bool = True) -> BinaryValue:
"""Read from an address.
Args:
address: The address to read from.
length: Number of words to transfer
sync: Wait for rising edge on clock initially. Defaults to True.
Returns:
The read data value.
Raises:
AXIProtocolError: If read response from AXI is not ``OKAY``.
"""
ret = await super().read(address=address, length=1, size=None,
burst=AXIBurst.INCR, return_rresp=False,
sync=sync)
return ret[0]
[docs]class AXI4Slave(BusDriver):
'''
AXI4 Slave
Monitors an internal memory and handles read and write requests.
'''
_signals = [
"ARREADY", "ARVALID", "ARADDR", # Read address channel
"ARLEN", "ARSIZE", "ARBURST", "ARPROT",
"RREADY", "RVALID", "RDATA", "RLAST", # Read response channel
"AWREADY", "AWADDR", "AWVALID", # Write address channel
"AWPROT", "AWSIZE", "AWBURST", "AWLEN",
"WREADY", "WVALID", "WDATA",
]
# Not currently supported by this driver
_optional_signals = [
"WLAST", "WSTRB",
"BVALID", "BREADY", "BRESP", "RRESP",
"RCOUNT", "WCOUNT", "RACOUNT", "WACOUNT",
"ARLOCK", "AWLOCK", "ARCACHE", "AWCACHE",
"ARQOS", "AWQOS", "ARID", "AWID",
"BID", "RID", "WID"
]
def __init__(self, entity, name, clock, memory, callback=None, event=None,
big_endian=False, **kwargs):
BusDriver.__init__(self, entity, name, clock, **kwargs)
self.clock = clock
self.big_endian = big_endian
self.bus.ARREADY.setimmediatevalue(1)
self.bus.RVALID.setimmediatevalue(0)
self.bus.RLAST.setimmediatevalue(0)
self.bus.AWREADY.setimmediatevalue(1)
self._memory = memory
self.write_address_busy = Lock("%s_wabusy" % name)
self.read_address_busy = Lock("%s_rabusy" % name)
self.write_data_busy = Lock("%s_wbusy" % name)
cocotb.fork(self._read_data())
cocotb.fork(self._write_data())
def _size_to_bytes_in_beat(self, AxSIZE):
if AxSIZE < 7:
return 2 ** AxSIZE
return None
async def _write_data(self):
clock_re = RisingEdge(self.clock)
while True:
while True:
self.bus.WREADY <= 0
await ReadOnly()
if self.bus.AWVALID.value:
self.bus.WREADY <= 1
break
await clock_re
await ReadOnly()
_awaddr = int(self.bus.AWADDR)
_awlen = int(self.bus.AWLEN)
_awsize = int(self.bus.AWSIZE)
_awburst = int(self.bus.AWBURST)
_awprot = int(self.bus.AWPROT)
burst_length = _awlen + 1
bytes_in_beat = self._size_to_bytes_in_beat(_awsize)
if __debug__:
self.log.debug(
"AWADDR %d\n" % _awaddr +
"AWLEN %d\n" % _awlen +
"AWSIZE %d\n" % _awsize +
"AWBURST %d\n" % _awburst +
"AWPROT %d\n" % _awprot +
"BURST_LENGTH %d\n" % burst_length +
"Bytes in beat %d\n" % bytes_in_beat)
burst_count = burst_length
await clock_re
while True:
if self.bus.WVALID.value:
word = self.bus.WDATA.value
word.big_endian = self.big_endian
_burst_diff = burst_length - burst_count
_st = _awaddr + (_burst_diff * bytes_in_beat) # start
_end = _awaddr + ((_burst_diff + 1) * bytes_in_beat) # end
self._memory[_st:_end] = array.array('B', word.buff)
burst_count -= 1
if burst_count == 0:
break
await clock_re
async def _read_data(self):
clock_re = RisingEdge(self.clock)
while True:
while True:
await ReadOnly()
if self.bus.ARVALID.value:
break
await clock_re
await ReadOnly()
_araddr = int(self.bus.ARADDR)
_arlen = int(self.bus.ARLEN)
_arsize = int(self.bus.ARSIZE)
_arburst = int(self.bus.ARBURST)
_arprot = int(self.bus.ARPROT)
burst_length = _arlen + 1
bytes_in_beat = self._size_to_bytes_in_beat(_arsize)
word = BinaryValue(n_bits=bytes_in_beat*8, bigEndian=self.big_endian)
if __debug__:
self.log.debug(
"ARADDR %d\n" % _araddr +
"ARLEN %d\n" % _arlen +
"ARSIZE %d\n" % _arsize +
"ARBURST %d\n" % _arburst +
"ARPROT %d\n" % _arprot +
"BURST_LENGTH %d\n" % burst_length +
"Bytes in beat %d\n" % bytes_in_beat)
burst_count = burst_length
await clock_re
while True:
self.bus.RVALID <= 1
await ReadOnly()
if self.bus.RREADY.value:
_burst_diff = burst_length - burst_count
_st = _araddr + (_burst_diff * bytes_in_beat)
_end = _araddr + ((_burst_diff + 1) * bytes_in_beat)
word.buff = self._memory[_st:_end].tobytes()
self.bus.RDATA <= word
if burst_count == 1:
self.bus.RLAST <= 1
await clock_re
burst_count -= 1
self.bus.RLAST <= 0
if burst_count == 0:
break