Source code for qrobot.qunits.qunit

import json
from typing import Dict, List, Optional

from ..bursts import Burst
from ..models import Model
from . import redis_utils
from .base import BaseUnit


[docs]class QUnit(BaseUnit): # pylint: disable=too-many-instance-attributes """[QUnit description] Parameters ------------ name : str The qUnit name model : qrobot.models.Model The model the qUnit implements burst : qrobot.bursts.Burst The burst the qUnit implements Ts : float The sampling time with wich the qUnit reads the input query : list, optional The target state for the model queries. Defaults to ``None`` in_units : dict[int, str], optional Dictionary containing {``dim`` : ``qunit_id``} inputs couplings, i.e. ``qunit_id`` output is the input for dimension ``dim``. Defaults to ``None``. default_input: List[float] Default input vector of scalar values to use as default value when qunit does not have an available one. Defaults to ``model.n*[0.0]`` Attributes ---------- id : str The unique instance identifier of the qUnit name : str The unique instance identifier of the qUnit model : qrobot.models.Model The model which the qUnit implements burst : qrobot.bursts.Burst The burst the qUnit implements Ts : float The sampling period for which the qUnit samples an event default_input: List[float] Default input vector of scalar values to use as default value when qunit does not have an available one """ def __init__( # pylint: disable=too-many-arguments self, name: str, model: Model, burst: Burst, Ts: float, # pylint: disable=invalid-name query: list = None, in_qunits: Dict[int, str] = None, default_input: float = None, ) -> None: # Call the BaseUnit constructor super().__init__(name, Ts) # Store the qUnits name and properties self.model = model self.burst = burst self.default_input = default_input or model.n * [0.0] # Default query to all 0s if not specified if query is None: query = [0.0] * (self.model.n) # Initialize multiprocessing variables # - Query array variable self._query = self._multiproc_manager.list(query) # - Output unit dictionary self._in_qunits = self._multiproc_manager.dict(in_qunits or {}) # - Time window index self._t_idx = self._multiproc_manager.Value("i", 0) # Log properties self._logger.debug(f"Properties: {self}") def __iter__(self): yield "name", self.name yield "id", self.id yield "model", str(self.model) yield "burst", str(self.burst.__class__) yield "query", self.query yield "Ts", self.Ts @property def query(self) -> List[float]: """Current target state for the model queries Returns ------- list The query target state array in the computational basis """ return list(self._query) @query.setter def query(self, query: list) -> None: """Set a new query state for the qunit Parameters ----------- query : list The query target state array in the computational basis """ # Check arguments query = self.model._target_vector_check(query) # Update accumulator self._logger.debug(f"Changing query from {self._query} to {query}") for idx, value in enumerate(query): self._query[idx] = value self._logger.debug(f"_query={self._query}") @property def in_qunits(self) -> Dict[int, str]: """Current output ``{dim : qunit_id}`` couplings. Returns ------- dict The current output ``{dim : qunit_id}`` couplings dictionary """ in_qunits = {} for dim in range(self.model.n): try: in_qunits[dim] = self._in_qunits[dim] except KeyError: in_qunits[dim] = None return in_qunits @property def input_vector(self) -> List[float]: """The current input vector of the unit Returns ------- list The current input vector """ input_vector = self.default_input for dim, qunit_id in self._in_qunits.items(): _r = redis_utils.get_redis() val = _r.get(qunit_id + " output") if val is not None: input_vector[dim] = float(val) else: self._logger.info(f"Unable to read {qunit_id} input") return input_vector
[docs] def set_input(self, dim: int, qunit_id: str) -> None: """Set a new input qunit for the desired dimension Parameters ----------- dim : int The input dimension index qunit_id : str The input qunit id """ # Check arguments dim = self.model._dim_index_check(dim) # pylint: disable=protected-access # Update accumulator self._logger.debug( f"Changing dim {dim} input from " + f"{self.in_qunits[dim]} to {qunit_id}" ) self._in_qunits[dim] = qunit_id self._logger.debug(f"_in_qunits={self._in_qunits}")
[docs] def get_burst_output(self) -> Optional[float]: """Get the latest burst output from the qUnit Returns ------- float The latest burst output written by the unit on the Redis database """ global_status = redis_utils.redis_status() out = global_status.get(f"{self.id} output", None) return float(out) if out else None
def _clean_redis(self) -> None: """Clean all the redis entries created by the unit when the loop stops.""" _r = redis_utils.get_redis() _r.delete(self.id + " output") _r.delete(self.id + " state") _r.delete(self.id + " query") _r.delete(self.id + " in_qunits") def _unit_task(self) -> None: """Single iteration of the processing loop.""" # "_t_idx" is the event index of the temporal window self._logger.debug( f"Temporal window event {self._t_idx.value+1}/{self.model.tau}" ) # Get input input_vector = self.input_vector self._logger.debug(f"input_vector={input_vector}") # Loop through the dimensions to encode data for dim in range(self.model.n): self.model.encode(input_vector[dim], dim) # Wait for the next input in the time window self._t_idx.value += 1 # If at the end of the time window if self._t_idx.value == self.model.tau: # Apply the query self._logger.debug(f"Querying for state {self._query}") self.model.query(self.query) # Decode out_state = self.model.decode() self._logger.debug(f"Output state = {out_state}") # Write output on Redis database self._logger.debug("Opening a connection to redis...") _r = redis_utils.get_redis() self._logger.debug(f"Redis connected: {_r}") if not ( _r.mset({self.id + " output": self.burst(out_state)}) and _r.mset({self.id + " state": str(out_state)}) and _r.mset({self.id + " query": json.dumps(self.query)}) and _r.mset({self.id + " in_qunits": json.dumps(self.in_qunits)}) ): raise Exception( f"Problem in writing qunit {self.id} output on Redis database!" ) # Initialize new temporal window self._logger.debug("Initializing a new temporal window") self.model.clear() self._t_idx.value = 0