# standard library imports
from collections import namedtuple
from enum import StrEnum
from dataclasses import dataclass
from math import fsum
from typing import ClassVar, Generator, List, Optional, Union
# third party imports
import h5py
from mantid.api import AnalysisDataService
from mantid.dataobjects import EventWorkspace
from mantid.simpleapi import CreateSingleValuedWorkspace, DeleteWorkspace, logger, mtd, RenameWorkspace
import numpy as np
# drtsans imports
from drtsans.api import _set_uncertainty_from_numpy
from drtsans.path import abspath
from drtsans.samplelogs import SampleLogs
from drtsans.type_hints import MantidWorkspace
class PolarizationLevel(StrEnum):
NONE = "none" # also 0
HALF = "half" # also 1
FULL = "full" # also 2
@classmethod
def from_int(cls, level: int) -> "PolarizationLevel":
r"""
Convert an integer polarization mode to a PolarizationLevel enum.
Parameters
----------
level : int
Integer representation of the polarization mode:
- 0: OFF (unpolarized)
- 1: HALF (polarizer is active)
- 2: FULL (both polarizer and analyzer are active)
Raises
------
ValueError
If the input integer does not correspond to a valid polarization mode.
"""
if level == 0:
return cls.NONE
elif level == 1:
return cls.HALF
elif level == 2:
return cls.FULL
else:
raise ValueError(f"Invalid polarization mode integer: {level}. Must be 0, 1, or 2.")
@classmethod
def get(cls, source: Union[str, EventWorkspace]) -> "PolarizationLevel":
r"""
Find if a run is polarized, and to what degree.
Parameters
----------
source : str, EventWorkspace
Either an instrument plus run number identifier (e.g., CG2_1235), a file name in the current directory,
an absolute file path, the name of an events workspace, or an EventWorkspace object.
Raises
------
TypeError
If the input is neither a string nor an EventWorkspace.
"""
mode = 0
# Determines polarization mode from workspace or file metadata
if isinstance(source, EventWorkspace):
sample_logs = SampleLogs(source)
if PV_POLARIZER in sample_logs:
mode += int(sample_logs.single_value(PV_POLARIZER))
if PV_ANALYZER in sample_logs:
mode += int(sample_logs.single_value(PV_ANALYZER))
elif isinstance(source, str):
# case 1: `source` is the name of a workspace in the AnalysisDataService
if AnalysisDataService.doesExist(source):
if isinstance(mtd[source], EventWorkspace):
sample_logs = SampleLogs(source)
if PV_POLARIZER in sample_logs:
mode += int(sample_logs.single_value(PV_POLARIZER))
if PV_ANALYZER in sample_logs:
mode += int(sample_logs.single_value(PV_ANALYZER))
else:
raise TypeError(f"The workspace '{source}' is not an EventWorkspace")
else:
# case 2: `source` is a file path
filepath = abspath(source)
with h5py.File(filepath, "r") as file_handle:
dataset = file_handle.get(f"/entry/DASlogs/{PV_POLARIZER}")
if dataset is not None: # log not found, assume unpolarized
mode += int(dataset["value"][()])
dataset = file_handle.get(f"/entry/DASlogs/{PV_ANALYZER}")
if dataset is not None:
mode += int(dataset["value"][()])
else:
raise TypeError(f"{source} must be either a string or an EventWorkspace object")
return cls.from_int(mode)
class PolarizationCrossSection(StrEnum):
    """Spin cross-section labels, named after the status of the flipper and of the analyzer."""

    NONE = "none"  # neither polarizer nor analyzer
    OFF = "off"  # no analyzer, flipper off
    ON = "on"  # no analyzer, flipper on
    OFF_OFF = "off_off"  # flipper off and analyzer off
    OFF_ON = "off_on"  # flipper off and analyzer on
    ON_OFF = "on_off"  # flipper on and analyzer off
    ON_ON = "on_on"  # flipper on and analyzer on

    @classmethod
    def get(cls, workspace: MantidWorkspace) -> "PolarizationCrossSection":
        """Read the cross-section stored in the sample logs of ``workspace`` under ``logname``."""
        logged_value = SampleLogs(workspace).single_value(cls.logname)
        return cls(str(logged_value))

    def log(self, workspace: MantidWorkspace):
        """Store this cross-section in the sample logs of ``workspace`` under ``logname``."""
        sample_logs = SampleLogs(workspace)
        sample_logs.insert(name=type(self).logname, value=self.value)

    @property
    def level(self) -> PolarizationLevel:
        """Polarization level implied by this cross-section."""
        half_polarized = (PolarizationCrossSection.OFF, PolarizationCrossSection.ON)
        if self in half_polarized:
            return PolarizationLevel.HALF
        if self == PolarizationCrossSection.NONE:
            return PolarizationLevel.NONE
        return PolarizationLevel.FULL


# The log name is attached once the enum exists; a plain assignment in the class body would
# have created an additional enum member instead of a class attribute.
PolarizationCrossSection.logname = "polarization: cross-section"
class PolarizationState(StrEnum):
    """Spin state labels, named after the spin of the upstream and downstream neutrons."""

    NONE = "none"  # unpolarized
    UP = "up"  # half polarization: upstream spin up
    DOWN = "down"  # half polarization: upstream spin down
    UP_UP = "up_up"  # full polarization: upstream up, downstream up
    UP_DOWN = "up_down"  # full polarization: upstream up, downstream down
    DOWN_UP = "down_up"  # full polarization: upstream down, downstream up
    DOWN_DOWN = "down_down"  # full polarization: upstream down, downstream down

    @classmethod
    def get(cls, workspace: MantidWorkspace) -> "PolarizationState":
        """Read the polarization state stored in the sample logs of ``workspace`` under ``logname``."""
        logged_value = SampleLogs(workspace).single_value(cls.logname)
        return cls(str(logged_value))

    def log(self, workspace: MantidWorkspace):
        """Store this polarization state in the sample logs of ``workspace`` under ``logname``."""
        sample_logs = SampleLogs(workspace)
        sample_logs.insert(name=type(self).logname, value=self.value)

    @property
    def level(self) -> PolarizationLevel:
        """Polarization level implied by this state."""
        half_polarized = (PolarizationState.UP, PolarizationState.DOWN)
        if self in half_polarized:
            return PolarizationLevel.HALF
        if self == PolarizationState.NONE:
            return PolarizationLevel.NONE
        return PolarizationLevel.FULL


# The log name is attached once the enum exists; a plain assignment in the class body would
# have created an additional enum member instead of a class attribute.
PolarizationState.logname = "polarization: state"
# Names exported by a star-import of this module.
# NOTE(review): PolarizationLevel, PolarizationCrossSection, PolarizationState, polarized_sample, and
# TimesGeneratorSpecs are defined in this module but are not listed here - confirm the omission is intended.
__all__ = [
    "PV_POLARIZER",
    "PV_POLARIZER_FLIPPER",
    "PV_POLARIZER_VETO",
    "PV_ANALYZER",
    "PV_ANALYZER_FLIPPER",
    "PV_ANALYZER_VETO",
    "half_polarization",
    "SimulatedPolarizationLogs",
]
# Names of the process variables (PVs) related to polarization, stored in the sample logs of the Nexus Event file.
# For the moment, we assume these PVs are consistent across instruments.
# PV_POLARIZER and PV_ANALYZER hold single values; the flipper and veto PVs are injected as time series
# by SimulatedPolarizationLogs.inject.
PV_POLARIZER = "Polarizer"
PV_POLARIZER_FLIPPER = "PolarizerFlipper"
PV_POLARIZER_VETO = "PolarizerVeto"
PV_ANALYZER = "Analyzer"
PV_ANALYZER_FLIPPER = "AnalyzerFlipper"
PV_ANALYZER_VETO = "AnalyzerVeto"
def polarized_sample(reduction_parameters: dict) -> bool:
    r"""
    Determine if the sample run involves polarized neutrons.

    This function checks for polarization under `configuration['polarization']['level']`.
    If the setting is missing, it examines the sample metadata to determine the polarization level.
    When only the reduction configuration is passed (no sample information is available), the level
    defaults to NONE and a warning is logged.

    Parameters
    ----------
    reduction_parameters : dict
        Dictionary of reduction configuration parameters. It can be either the full reduction input containing
        the "configuration" key or just the reduction configuration itself.

    Returns
    -------
    bool
        True if the sample is polarized (polarization level is not NONE), False otherwise.

    Raises
    ------
    ValueError
        If multiple sample runs are provided and the first sample represents a polarized run.
        Currently, we can't reduce polarization for summed datasets.

    Notes
    -----
    The function modifies the reduction configuration by setting the polarization level in
    `configuration['polarization']['level']` if not previously specified. Any other entry already
    present under `configuration['polarization']` is preserved.
    """
    has_full_input = "configuration" in reduction_parameters
    reduction_config = reduction_parameters["configuration"] if has_full_input else reduction_parameters

    # `or {}` also covers an explicit `"polarization": None` entry
    polarization = reduction_config.get("polarization") or {}

    if polarization.get("level") is None:
        if has_full_input:
            run_numbers = reduction_parameters["sample"]["runNumber"].strip().split(",")
            multiple_samples = len(run_numbers) > 1
            # with several (summed) runs, inquire from the first one only
            sample_filepath = abspath(
                run_numbers[0].strip(),
                instrument=reduction_parameters["instrumentName"],
                ipts=reduction_parameters["iptsNumber"],
                directory=reduction_parameters.get("dataDirectories", None),
                search_archive=True,
            )
            level = PolarizationLevel.get(sample_filepath)
            if multiple_samples and level != PolarizationLevel.NONE:
                raise ValueError("Can't do polarization reduction on summed data sets")
        else:
            # the bare configuration carries no sample information to inspect
            logger.warning("Unable to resolve polarization level. Setting to NONE by default.")
            level = PolarizationLevel.NONE
        # update the existing settings in place instead of replacing them, so sibling keys survive
        polarization["level"] = str(level)
        reduction_config["polarization"] = polarization

    return reduction_config["polarization"]["level"] != PolarizationLevel.NONE
def _calc_flipping_ratio(polarization):
    """Compute the flipping ratio ``(1 + P) / (1 - P)`` from the polarization ``P``.

    The uncertainty is propagated as ``2 * dP / (1 - P)^2``. Values and uncertainties are read
    from the first spectrum of the input.

    Parameters
    ----------
    polarization: ~mantid.api.MatrixWorkspace
        Polarization state

    Returns
    -------
    ~mantid.api.MatrixWorkspace
        The ratio of flipping, as a hidden single-valued workspace

    Raises
    ------
    NotImplementedError
        If the polarization is not a single value.
    """
    pol_value = polarization.extractY()[0]
    pol_error = polarization.extractE()[0]
    single_valued = len(pol_value) == 1 and len(pol_error) == 1
    if single_valued:
        pol_value, pol_error = pol_value[0], pol_error[0]

    one_minus_pol = 1 - pol_value
    ratio_error = 2.0 * pol_error / np.square(one_minus_pol)
    ratio = (1 + pol_value) / one_minus_pol

    if not isinstance(ratio, float):
        # if it is an array should call CreateWorkspace(EnableLogging=False)
        raise NotImplementedError(f"Somebody needs to create an output from {ratio} (type={type(ratio)})")

    # a scalar result is stored in a single value workspace
    return CreateSingleValuedWorkspace(
        DataValue=ratio,
        ErrorValue=ratio_error,
        OutputWorkspace=mtd.unique_hidden_name(),
        EnableLogging=False,
    )
def _calc_half_polarization_up(flipper_off, flipper_on, efficiency, flipping_ratio):
    """Compute the spin up workspace from the flipper off/on measurements.

    With ``M0``/``M1`` the flipper off/on measurements, ``e`` the flipper efficiency and ``F`` the
    flipping ratio, the result is ``M0 + (M0 - M1) / (e * (F - 1))``.

    Parameters
    ----------
    flipper_off: ~mantid.api.MatrixWorkspace
        Flipper off measurement
    flipper_on: ~mantid.api.MatrixWorkspace
        Flipper on measurement
    efficiency: ~mantid.api.MatrixWorkspace
        Flipper efficiency
    flipping_ratio: ~mantid.api.MatrixWorkspace
        The ratio of flipping

    Returns
    -------
    ~mantid.api.MatrixWorkspace
        The spin up workspace
    """
    # NOTE(review): the variable name is kept on purpose - Mantid presumably derives the name of the output
    # workspace from the assigned variable, and a leading double underscore keeps it hidden. Confirm before renaming.
    __spin_up = flipper_off + (flipper_off - flipper_on) / (efficiency * (flipping_ratio - 1.0))

    eff = efficiency.extractY()[0]
    ratio = flipping_ratio.extractY()[0]
    denominator = eff * (ratio - 1)

    # The uncertainties produced by the workspace arithmetic aren't correct because of the build-up of
    # numerical errors. They are recalculated from the hand-derived partial derivatives.
    off_term = np.square(flipper_off.extractE()[0][0] * (1 + 1 / denominator))
    on_term = np.square(flipper_on.extractE()[0][0] * (1 / denominator))
    difference = flipper_off.extractY()[0][0] - flipper_on.extractY()[0][0]
    # combined contribution of the efficiency and flipping ratio uncertainties
    relative_errors = np.square(efficiency.extractE()[0][0] / eff) + np.square(
        flipping_ratio.extractE()[0][0] / (ratio - 1)
    )
    cross_term = np.square(difference / denominator) * relative_errors
    sup_err = np.sqrt(off_term + on_term + cross_term)

    # overwrite the uncertainty in the workspace
    __spin_up = _set_uncertainty_from_numpy(__spin_up, sup_err)
    return __spin_up
def _calc_half_polarization_down(flipper_off, flipper_on, efficiency, flipping_ratio):
    """Compute the spin down workspace from the flipper off/on measurements.

    With ``M0``/``M1`` the flipper off/on measurements, ``e`` the flipper efficiency and ``F`` the
    flipping ratio, the result is ``M0 - (M0 - M1) / (e * (1 - 1 / F))``.

    Parameters
    ----------
    flipper_off: ~mantid.api.MatrixWorkspace
        Flipper off measurement
    flipper_on: ~mantid.api.MatrixWorkspace
        Flipper on measurement
    efficiency: ~mantid.api.MatrixWorkspace
        Flipper efficiency
    flipping_ratio: ~mantid.api.MatrixWorkspace
        The ratio of flipping

    Returns
    -------
    ~mantid.api.MatrixWorkspace
        The spin down workspace
    """
    # NOTE(review): the variable name is kept on purpose - Mantid presumably derives the name of the output
    # workspace from the assigned variable, and a leading double underscore keeps it hidden. Confirm before renaming.
    __spin_down = flipper_off - (flipper_off - flipper_on) / (efficiency * (1.0 - 1.0 / flipping_ratio))

    eff = efficiency.extractY()[0]
    ratio = flipping_ratio.extractY()[0]
    one_minus_inverse = 1 - 1 / ratio
    denominator = eff * one_minus_inverse

    # The uncertainties produced by the workspace arithmetic aren't correct because of the build-up of
    # numerical errors. They are recalculated from the hand-derived partial derivatives.
    off_term = np.square(flipper_off.extractE()[0][0] * (1 - 1 / denominator))
    on_term = np.square(flipper_on.extractE()[0][0] * (1 / denominator))
    difference = flipper_off.extractY()[0][0] - flipper_on.extractY()[0][0]
    # combined contribution of the efficiency and flipping ratio uncertainties
    relative_errors = np.square(efficiency.extractE()[0][0] / eff) + np.square(
        flipping_ratio.extractE()[0][0] / (ratio * ratio * one_minus_inverse)
    )
    cross_term = np.square(difference / denominator) * relative_errors
    sdn_err = np.sqrt(off_term + on_term + cross_term)

    # overwrite the uncertainty in the workspace
    __spin_down = _set_uncertainty_from_numpy(__spin_down, sdn_err)
    return __spin_down
def half_polarization(
    flipper_off_workspace,
    flipper_on_workspace,
    polarization,
    efficiency,
    spin_up_workspace=None,
    spin_down_workspace=None,
):
    """Calculate the spin up/down workspaces from flipper on/off.

    **Mantid algorithms used:**
    :ref:`RenameWorkspace <algm-RenameWorkspace-v1>`

    Parameters
    ----------
    flipper_off_workspace: str, ~mantid.api.MatrixWorkspace
        Flipper off measurement
    flipper_on_workspace: str, ~mantid.api.MatrixWorkspace
        Flipper on measurement
    polarization: str, ~mantid.api.MatrixWorkspace
        Polarization state
    efficiency: str, ~mantid.api.MatrixWorkspace
        Flipper efficiency
    spin_up_workspace: str
        Name of the resulting spin up workspace. If :py:obj:`None`, then
        ``flipper_off_workspace`` will be overwritten.
    spin_down_workspace: str
        Name of the resulting spin down workspace. If :py:obj:`None`, then
        ``flipper_on_workspace`` will be overwritten.

    Returns
    -------
    py:obj:`tuple` of 2 ~mantid.api.MatrixWorkspace
        The spin up and spin down workspaces, in this order.
    """
    if spin_up_workspace is None:
        spin_up_workspace = str(flipper_off_workspace)
    if spin_down_workspace is None:
        spin_down_workspace = str(flipper_on_workspace)

    # Inputs may be given as workspace names, but the calculations below operate on workspace handles
    flipper_off, flipper_on, polarization, efficiency = (
        mtd[item] if isinstance(item, str) else item
        for item in (flipper_off_workspace, flipper_on_workspace, polarization, efficiency)
    )

    # this is denoted as "F" in the master document
    flipping_ratio = _calc_flipping_ratio(polarization)
    try:
        # both results are computed before any renaming, because renaming may overwrite the inputs
        __spin_up = _calc_half_polarization_up(flipper_off, flipper_on, efficiency, flipping_ratio)
        __spin_down = _calc_half_polarization_down(flipper_off, flipper_on, efficiency, flipping_ratio)
        spin_up_workspace = RenameWorkspace(InputWorkspace=__spin_up, OutputWorkspace=spin_up_workspace)
        spin_down_workspace = RenameWorkspace(InputWorkspace=__spin_down, OutputWorkspace=spin_down_workspace)
    finally:
        # the temporary flipping-ratio workspace is removed even if the calculation fails
        DeleteWorkspace(flipping_ratio)
    return (spin_up_workspace, spin_down_workspace)
# A simple way to encode the name and specifications for one of the time-generator methods of class
# SimulatedPolarizationLogs.
# Example: polarizer_veto=TimesGeneratorSpecs("binary_pulse", {"interval": 1.0, "alive_duration": 0.2})
TimesGeneratorSpecs = namedtuple("TimesGeneratorSpecs", ["name", "kwargs"])


@dataclass
class SimulatedPolarizationLogs:
    """Simulated polarization logs, for testing purposes.

    The flipper and veto fields name one of the time-generator methods of this class, together with
    the keyword arguments to call it with. A field left as :py:obj:`None` produces no log.
    """

    polarizer: int = 0  # value inserted for the polarizer log
    polarizer_flipper: Optional[TimesGeneratorSpecs] = None
    polarizer_veto: Optional[TimesGeneratorSpecs] = None
    analyzer: int = 0  # value inserted for the analyzer log
    analyzer_flipper: Optional[TimesGeneratorSpecs] = None
    analyzer_veto: Optional[TimesGeneratorSpecs] = None

    # Class variables
    flipper_generators: ClassVar[List[str]] = ["heartbeat", "binary_pulse", "cycled_intervals"]
    veto_generators: ClassVar[List[str]] = ["binary_pulse"]  # available time-generators for veto intervals

    def __post_init__(self):
        """Check that every flipper and veto specification names an allowed time generator.

        Raises
        ------
        ValueError
            If a flipper or veto generator name is not among the allowed ones.
        """
        # Validate input polarizer and analyzer flipper generators
        for device, flipper in (("polarizer", self.polarizer_flipper), ("analyzer", self.analyzer_flipper)):
            if flipper and flipper.name not in self.flipper_generators:
                raise ValueError(
                    f"The {device} flipper generator must be one of {self.flipper_generators}, got '{flipper.name}'"
                )
        # Validate input polarizer and analyzer veto generators
        for device, veto in (("polarizer", self.polarizer_veto), ("analyzer", self.analyzer_veto)):
            if veto and veto.name not in self.veto_generators:
                raise ValueError(
                    f"The {device} veto generator must be one of {self.veto_generators}, got '{veto.name}'"
                )

    def heartbeat(
        self, interval: float, dead_time: Optional[float] = 0.0, upper_bound: Optional[float] = None
    ) -> Generator[float, None, None]:
        """
        Generate a sequence of timestamps at regular intervals, starting at or later than dead_time.

        Parameters
        ----------
        interval : float
            The time interval between consecutive timestamps, in seconds. Must be positive.
        dead_time: float
            The initial time period, in seconds, during which no times are generated. Defaults to 0.0.
        upper_bound : float, optional
            The maximum time value to generate, in seconds. If None, the generator will continue indefinitely.

        Yields
        ------
        float
            The next timestamp in the sequence.

        Raises
        ------
        ValueError
            If ``interval`` is not positive, since the sequence would never advance.

        Examples
        --------
        >>> log = SimulatedPolarizationLogs()
        >>> list(log.heartbeat(interval=1.0, upper_bound=5.0))
        [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
        """
        if not interval > 0:
            raise ValueError("The interval must be positive")
        tick = 0
        while True:
            # a single multiplication per timestamp avoids the rounding drift of repeated additions
            elapsed = tick * interval
            if upper_bound is not None and elapsed > upper_bound:
                return
            if elapsed >= dead_time:
                yield elapsed
            tick += 1

    def cycled_intervals(
        self,
        intervals: List[float],
        dead_time: Optional[float] = 0.0,
        upper_bound: Optional[float] = None,
    ) -> Generator[float, None, None]:
        """
        Generate a sequence of timestamps by repeatedly cycling through a list of time intervals.

        Starting from time zero, each successive timestamp is computed by adding the next interval
        in ``intervals`` to the current elapsed time. Once all intervals have been consumed, the
        cycle restarts from the first interval. Timestamps are only yielded when at or after
        ``dead_time``.

        Parameters
        ----------
        intervals : list of float
            A list of time intervals, in seconds, to cycle through repeatedly. It cannot be empty and
            one full cycle must add up to a positive time.
        dead_time : float, optional
            The initial time period, in seconds, during which no times are generated. Defaults to 0.0.
        upper_bound : float, optional
            The maximum time value to generate, in seconds. If :py:obj:`None`, the generator
            will continue indefinitely.

        Yields
        ------
        float
            The next timestamp in the cycled intervals sequence.

        Raises
        ------
        ValueError
            If ``intervals`` is empty or if one full cycle does not add up to a positive time.

        Examples
        --------
        >>> log = SimulatedPolarizationLogs()
        >>> list(log.cycled_intervals(intervals=[1.0, 2.0, 1.5], dead_time=0.0, upper_bound=10.0))
        [0.0, 1.0, 3.0, 4.5, 5.5, 7.5, 9.0, 10.0]
        """
        if len(intervals) == 0:
            raise ValueError("At least one interval is required")
        if not fsum(intervals) > 0:
            raise ValueError("The intervals of one cycle must add up to a positive time")
        accumulated, elapsed, i = [], 0.0, 0
        while upper_bound is None or elapsed <= upper_bound:
            if elapsed >= dead_time:
                yield elapsed
            accumulated.append(intervals[i % len(intervals)])  # wrap around at the end of a cycle
            i += 1
            # timestamps are computed with math.fsum over all accumulated intervals rather than by
            # incremental addition, because summing many small floats incrementally causes rounding errors to
            # drift (e.g. yielding 0.19999999 instead of 0.2).
            elapsed = fsum(accumulated)

    def binary_pulse(
        self,
        interval: float,
        alive_duration: float,
        dead_time: Optional[float] = 0.0,
        upper_bound: Optional[float] = None,
    ) -> Generator[float, None, None]:
        """
        Generate a sequence of timestamps with a binary pulse pattern, starting at or later than dead_time.

        Time zero comes first. It is followed by the start and the end of each pulse, where pulses have
        length ``alive_duration`` and are centered at successive multiples of ``interval``.

        Parameters
        ----------
        interval : float
            The time interval between consecutive pulses, in seconds. Must be positive.
        alive_duration : float
            The duration of the period above the zero baseline, in seconds. Must be less than `interval`.
        dead_time: float
            The initial time period, in seconds, during which no times are generated. Defaults to 0.0.
        upper_bound : float, optional
            The maximum time value to generate, in seconds. If None, the generator will continue indefinitely.

        Yields
        ------
        float
            The next timestamp in the binary pulse sequence.

        Raises
        ------
        ValueError
            If ``alive_duration`` is not less than ``interval``, or if ``interval`` is not positive.

        Examples
        --------
        >>> log = SimulatedPolarizationLogs()
        >>> list(log.binary_pulse(interval=3.0, alive_duration=1.0, upper_bound=10))
        [0.0, 2.5, 3.5, 5.5, 6.5, 8.5, 9.5]
        """
        if not alive_duration < interval:
            raise ValueError("Veto duration must be less than the interval")
        if not interval > 0:
            raise ValueError("The interval must be positive")
        veto_half = alive_duration / 2
        # time zero obeys the same dead-time and upper-bound rules as every other timestamp
        if 0.0 >= dead_time and (upper_bound is None or 0.0 <= upper_bound):
            yield 0.0
        pulse = 1
        while True:
            # a single multiplication per pulse avoids the rounding drift of repeated additions
            pulse_center = pulse * interval
            for elapsed in (pulse_center - veto_half, pulse_center + veto_half):
                if upper_bound is not None and elapsed > upper_bound:
                    return
                if elapsed >= dead_time:
                    yield elapsed
            pulse += 1

    def times_generator(self, pv_name: str, **options: object) -> Optional[Generator[float, None, None]]:
        """
        Create a generator that yields time points

        This method selects the appropriate time generator function (e.g., `heartbeat` or `binary_pulse`)
        based on the PV name and its associated generator specifications. Additional options can be passed
        to override or extend the generator specifications.

        Parameters
        ----------
        pv_name : str
            The name of the process variable (e.g., 'PolarizerFlipper', 'PolarizerVeto', etc.).
        **options
            Additional keyword arguments to override or extend the generator's default arguments.

        Returns
        -------
        Generator or None
            :py:obj:`None` if no generator specifications were given for this process variable.

        Raises
        ------
        KeyError
            If the provided PV name does not match any known process variable.
        AttributeError
            If the generator function associated with the PV name is not found.

        Examples
        --------
        >>> logs = SimulatedPolarizationLogs(
        ...     polarizer_flipper=TimesGeneratorSpecs("heartbeat", {"interval": 1.0}),
        ...     polarizer_veto=TimesGeneratorSpecs("binary_pulse", {"interval": 1.0, "alive_duration": 0.4})
        ... )
        >>> list(logs.times_generator(PV_POLARIZER_FLIPPER, upper_bound=6.3))
        [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        >>> list(logs.times_generator(PV_POLARIZER_VETO, upper_bound=6.3))
        [0.0, 0.8, 1.2, 1.8, 2.2, 2.8, 3.2, 3.8, 4.2, 4.8, 5.2, 5.8, 6.2]
        >>> logs.times_generator(PV_ANALYZER_FLIPPER) is None
        True
        >>> logs.times_generator(PV_ANALYZER_VETO) is None
        True
        """
        # conversion between PV name and class field
        converter = {
            PV_POLARIZER_FLIPPER: self.polarizer_flipper,
            PV_POLARIZER_VETO: self.polarizer_veto,
            PV_ANALYZER_FLIPPER: self.analyzer_flipper,
            PV_ANALYZER_VETO: self.analyzer_veto,
        }
        specs_field = converter[pv_name]
        if specs_field is None:
            return None
        generator_function = getattr(self, specs_field.name)
        # entries in `options` take precedence over the stored specifications
        kwargs = {**specs_field.kwargs, **options}
        return generator_function(**kwargs)

    def inject(self, input_workspace: "MantidWorkspace"):
        """
        Injects simulated log data into a Mantid workspace.

        This method adds polarizer and analyzer values as single-valued logs, and generates time-series logs
        for flippers and veto process variables based on their associated time-generator specifications.
        Values for flipper and veto time-series are either 0 or 1, and always start with 0 for simplicity.
        The run duration found in the sample logs is the upper bound of the generated times.

        Parameters
        ----------
        input_workspace : MantidWorkspace
            The Mantid workspace into which the simulated logs will be injected.

        Raises
        ------
        AttributeError
            If the workspace does not contain required sample log entries like `run_start` or `duration`.

        Examples
        --------
        >>> workspace = CreateSingleValuedWorkspace(OutputWorkspace="example")
        >>> sample_logs = SampleLogs(workspace)
        >>> sample_logs.insert("start_time", "2023-10-01T00:00:00")
        >>> sample_logs.insert("duration", 300)
        >>> logs = SimulatedPolarizationLogs(
        ...     polarizer=1,
        ...     polarizer_flipper=TimesGeneratorSpecs("heartbeat", {"interval": 1.0}),
        ...     polarizer_veto=TimesGeneratorSpecs("binary_pulse", {"interval": 2.0, "alive_duration": 0.5})
        ... )
        >>> logs.inject(workspace)
        """
        sample_logs = SampleLogs(input_workspace)
        # Retrieve the run start time and duration from the sample logs, handling potential attribute differences.
        try:
            run_start = sample_logs.run_start.value
        except AttributeError:
            run_start = sample_logs.start_time.value
        duration: float = sample_logs.duration.value  # in seconds
        # insert polarizer and analyzer types
        sample_logs.insert(name=PV_POLARIZER, value=self.polarizer)
        sample_logs.insert(name=PV_ANALYZER, value=self.analyzer)
        # insert the time series
        for pv in [PV_POLARIZER_FLIPPER, PV_POLARIZER_VETO, PV_ANALYZER_FLIPPER, PV_ANALYZER_VETO]:
            times = self.times_generator(pv, upper_bound=duration)
            if times is None:  # no time generator specifications for this PV
                continue
            times = list(times)  # run the generator to get all times
            values = [i % 2 for i in range(len(times))]  # Alternating zeros and ones
            sample_logs.insert_time_series(name=pv, start_time=run_start, elapsed_times=times, values=values)