#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Date : 2026-04-26
# Author : Lancelot PINCET
# GitHub : https://github.com/LancelotPincet
# Library : funcLP
# Module : JointFunction
"""
This module defines JointFunction, a composite function that jointly applies (and fits) several models, one per channel.
"""
# %% Libraries
from corelp import prop, selfkwargs
from funclp import CudaReference, Parameter, JointChannel
import inspect
import numpy as np
import numba as nb
# %% Joint function
class JointFunction(CudaReference):
    '''Composite function used to jointly fit several independent channels.

    Parameters
    ----------
    channels : list
        Functions or JointChannel instances.
    shared_parameters : dict, optional
        Mapping from joint parameter name to per-channel child parameter names.
        Use None for channels that do not share this parameter.
        Example: {'x0': ['mux', None, 'x0'], 'sigma': [None, 'sig', 'sig']}.
    shared_variables : dict, optional
        Mapping from joint variable name to per-channel child variable names.
        Use None for channels that do not use this joint variable.
        Example: {'x': ['x', 'xx'], 'y': ['y', 'yy'], 'z': ['z', None]}.
    kwargs : dict
        Attributes to change after construction.

    Raises
    ------
    ValueError
        If fewer than two channels are given, if a prefix / joint name is not
        a valid identifier, if a mapping has the wrong length or references a
        missing child name, or if two kernel inputs end up with the same name.

    Notes
    -----
    Data inputs are intentionally kept channel-local. They are passed to
    prepare_fit_inputs as one data block per channel, because different channels
    may have different data input names and shapes.

    Affine transforms are fixed and applied on variables, not parameters. Since
    affine coefficients are not fitted parameters, no derivatives are generated
    for them.
    '''

    @prop()
    def name(self):
        '''Name of the function: the class name (wrapped by corelp ``prop``).'''
        return self.__class__.__name__

    def __init__(self, channels, shared_parameters=None, shared_variables=None, **kwargs):
        # Wrap bare functions so every entry exposes the JointChannel interface.
        self.channels = [channel if isinstance(channel, JointChannel) else JointChannel(channel) for channel in channels]
        if len(self.channels) < 2:
            raise ValueError('JointFunction needs at least two channels')
        # Copy the mappings so normalisation never mutates the caller's dicts.
        self.shared_parameters = {} if shared_parameters is None else dict(shared_parameters)
        self.shared_variables = {} if shared_variables is None else dict(shared_variables)
        # Order matters: kernels are generated from the names collected by the
        # previous setup steps.
        self._setup_channels()
        self._setup_variables_and_data()
        self._setup_parameters()
        self._setup_constants()
        self._setup_fit_ufunc()
        self._build_kernels()
        selfkwargs(self, kwargs)
        # kwargs may have changed parameter values: push them to the children.
        self._sync_channels()

    # ---------------------------------------------------------------------
    # Setup
    # ---------------------------------------------------------------------

    def _setup_channels(self):
        '''Give every channel a unique identifier-like prefix (default ``ch<pos>``).'''
        prefixes = set()
        for pos, channel in enumerate(self.channels):
            channel.prefix = channel.prefix or f'ch{pos}'
            # Prefixes become part of generated argument names.
            if not channel.prefix.isidentifier():
                raise ValueError(f'{channel.prefix} is not a valid parameter prefix')
            if channel.prefix in prefixes:
                raise ValueError(f'Duplicate channel prefix {channel.prefix}')
            prefixes.add(channel.prefix)

    def _setup_variables_and_data(self):
        '''Collect joint variable names, per-channel variable expressions and data names.'''
        self._validate_shared_variables()
        # 'channel' is always the first variable: the scalar kernels dispatch on it.
        variable_names = ['channel']
        for key in self.shared_variables:
            if not key.isidentifier():
                raise ValueError(f'{key} is not a valid joint variable name')
            _unique_append(variable_names, key)
        # For each channel: child variable name -> expression written in terms
        # of joint variables (with the channel's affine transform applied).
        self._channel_variable_exprs = [
            {cname: self._mapped_variable_expression(pos, cname) for cname in channel.function.variables}
            for pos, channel in enumerate(self.channels)
        ]
        self.variables = variable_names
        # Data is channel-local. Each data input receives a unique joint name,
        # so the kernels can still have one flat input signature.
        self.data = []
        self._channel_data_names = []
        for channel in self.channels:
            channel_names = {}
            for dname in channel.function.data:
                joint_name = f'{channel.prefix}_{dname}'
                self.data.append(joint_name)
                channel_names[dname] = joint_name
            self._channel_data_names.append(channel_names)

    def _validate_shared_variables(self):
        '''Normalise ``shared_variables`` to lists and check every referenced child name.

        Raises
        ------
        ValueError
            If a mapping is malformed, names a variable a channel does not
            have, or if an affine variable is unknown / not mapped.
        '''
        n = len(self.channels)
        normalized = {}
        for joint_name, mapping in self.shared_variables.items():
            mapping = _as_list(mapping, n, f'shared_variables[{joint_name!r}]')
            normalized[joint_name] = mapping
            for pos, child_name in enumerate(mapping):
                if child_name is None:
                    continue
                if child_name not in self.channels[pos].function.variables:
                    raise ValueError(f'Variable {child_name} from shared variable {joint_name} is missing from channel {pos}')
        self.shared_variables = normalized
        for pos, channel in enumerate(self.channels):
            # Child variables of this channel that are reachable from a joint variable.
            mapped = {values[pos] for values in self.shared_variables.values() if values[pos] is not None}
            for group in channel.affine.keys():
                for child_name in group:
                    if child_name not in channel.function.variables:
                        raise ValueError(f'Affine variable {child_name} is missing from channel {pos}')
                    if child_name not in mapped:
                        raise ValueError(f'Affine variable {child_name} in channel {pos} must be mapped by shared_variables')

    def _setup_parameters(self):
        '''Build the flat joint parameter list (shared first, then channel-local).

        Shared parameters keep their joint name; every other child parameter
        is exposed as ``<prefix>_<name>``. For each joint parameter the value
        and its ``_min`` / ``_max`` / ``_fit`` companions are set as attributes.
        '''
        self._validate_shared_parameters()
        names = []
        specs = {}
        self._shared_parameter_sources = {}
        for joint_name, mapping in self.shared_parameters.items():
            source = self._shared_source(joint_name, mapping)
            specs[joint_name] = Parameter(
                joint_name,
                source.default,
                estimate=source.estimate,
                bounds=source.bounds,
                fit=source.fit,
            )
            names.append(joint_name)
            self._shared_parameter_sources[joint_name] = source
        self._channel_parameter_names = []
        self._channel_shared_parameter_names = []
        for pos, channel in enumerate(self.channels):
            local_names = {}
            shared_names = {}
            linked_child_params = self._linked_child_parameters(pos)
            for pname, value in channel.function.parameters.items():
                if pname in linked_child_params:
                    # Linked to a joint parameter: no channel-local copy.
                    shared_names[pname] = linked_child_params[pname]
                    continue
                name = f'{channel.prefix}_{pname}'
                spec = channel.function.__class__.function.parameter_specs[pname]
                specs[name] = Parameter(
                    name,
                    value,
                    estimate=spec.estimate,
                    bounds=spec.bounds,
                    fit=getattr(channel.function, f'{pname}_fit'),
                )
                names.append(name)
                local_names[pname] = name
            self._channel_parameter_names.append(local_names)
            self._channel_shared_parameter_names.append(shared_names)
        self._parameters_names = names
        self._parameter_specs = specs
        for key in self._parameters_names:
            lower, upper = specs[key].bounds[0], specs[key].bounds[1]
            setattr(self, key, convert(specs[key].default))
            # A missing bound means "unbounded" on that side.
            setattr(self, f'{key}_min', -np.float32(np.inf) if lower is None else lower)
            setattr(self, f'{key}_max', +np.float32(np.inf) if upper is None else upper)
            setattr(self, f'{key}_fit', bool(specs[key].fit))

    def _validate_shared_parameters(self):
        '''Normalise ``shared_parameters`` to lists and check the links.

        Raises
        ------
        ValueError
            If a joint name is not an identifier, a mapping is malformed, a
            child parameter is missing, a joint parameter is linked to no
            channel, or one child parameter is linked by two joint parameters.
        '''
        n = len(self.channels)
        normalized = {}
        for joint_name, mapping in self.shared_parameters.items():
            if not joint_name.isidentifier():
                raise ValueError(f'{joint_name} is not a valid shared parameter name')
            mapping = _as_list(mapping, n, f'shared_parameters[{joint_name!r}]')
            normalized[joint_name] = mapping
            found = False
            for pos, child_name in enumerate(mapping):
                if child_name is None:
                    continue
                found = True
                if child_name not in self.channels[pos].function.parameters:
                    raise ValueError(f'Parameter {child_name} from shared parameter {joint_name} is missing from channel {pos}')
            if not found:
                raise ValueError(f'Shared parameter {joint_name} is not linked to any channel')
        self.shared_parameters = normalized
        for pos, channel in enumerate(self.channels):
            seen = {}  # child parameter name -> joint name that already claimed it
            for joint_name, mapping in self.shared_parameters.items():
                child_name = mapping[pos]
                if child_name is None:
                    continue
                if child_name in seen:
                    raise ValueError(f'Channel {pos} parameter {child_name} is linked by both {seen[child_name]} and {joint_name}')
                seen[child_name] = joint_name

    def _shared_source(self, joint_name, mapping):
        '''Return a Parameter built from the first channel linked to ``joint_name``.

        The value, estimate, bounds and fit flag of that first linked child
        parameter are used for the joint parameter.
        '''
        for pos, child_name in enumerate(mapping):
            if child_name is None:
                continue
            channel = self.channels[pos]
            spec = channel.function.__class__.function.parameter_specs[child_name]
            return Parameter(
                joint_name,
                channel.function.parameters[child_name],
                estimate=spec.estimate,
                bounds=spec.bounds,
                fit=getattr(channel.function, f'{child_name}_fit'),
            )
        raise ValueError(f'Shared parameter {joint_name} was not found in any channel')

    def _linked_child_parameters(self, pos):
        '''Return ``{child parameter name: joint parameter name}`` for channel ``pos``.'''
        linked = {}
        for joint_name, mapping in self.shared_parameters.items():
            child_name = mapping[pos]
            if child_name is not None:
                linked[child_name] = joint_name
        return linked

    def _setup_constants(self):
        '''Expose every child constant under the joint name ``<prefix>_<name>``.'''
        self._constant_names = []
        self._constant_values = {}
        self._channel_constant_names = []
        for channel in self.channels:
            channel_names = {}
            for cname, value in channel.function.constants.items():
                name = f'{channel.prefix}_{cname}'
                self._constant_names.append(name)
                self._constant_values[name] = value
                channel_names[cname] = name
            self._channel_constant_names.append(channel_names)

    def _setup_fit_ufunc(self):
        '''Create the object describing the flat (variables, data, parameters) signature.'''
        self.fit_ufunc = _FitUfunc(self.variables, self.data, self._parameters_names, self._parameter_specs)

    # ---------------------------------------------------------------------
    # Public properties
    # ---------------------------------------------------------------------

    @property
    def parameters(self):
        '''Current joint parameter values as ``{name: value}``.'''
        return {key: getattr(self, key) for key in self._parameters_names}

    @parameters.setter
    def parameters(self, values):
        for key, value in values.items():
            setattr(self, key, value)
        # Keep the child functions consistent with the joint values.
        self._sync_channels()

    @property
    def constants(self):
        '''Current joint constant values as ``{name: value}``.'''
        return {key: self._constant_values[key] for key in self._constant_names}

    @constants.setter
    def constants(self, values):
        self._constant_values.update(values)

    @property
    def nmodels(self):
        '''Number of models: the broadcast length of the parameters (0 if all are scalars).

        Raises
        ------
        ValueError
            If the broadcast parameter shape has more than one dimension.
        '''
        shape = np.broadcast_shapes(*[np.shape(getattr(self, param, [])) for param in self.parameters])
        if len(shape) > 1:
            raise ValueError('Parameters cannot have more than 1 dimension')
        return shape[0] if len(shape) == 1 else 0

    def _sync_channels(self):
        '''Copy joint parameter values (local and shared) into each child function.'''
        for pos, channel in enumerate(self.channels):
            values = {}
            values.update({pname: getattr(self, name) for pname, name in self._channel_parameter_names[pos].items()})
            values.update({pname: getattr(self, name) for pname, name in self._channel_shared_parameter_names[pos].items()})
            channel.function.parameters = values

    # ---------------------------------------------------------------------
    # Fit input preparation
    # ---------------------------------------------------------------------

    def _get_channel_inputs(self, args):
        '''Return one ``(variables, data_inputs)`` pair of dicts per channel.

        ``args`` is either empty (channel defaults are used) or holds a single
        list/tuple with one entry per channel, each in a format accepted by
        ``_normalize_channel_input``.

        Raises
        ------
        ValueError
            If the format is not recognised or a variable / data input that a
            channel function needs is missing.
        '''
        if len(args) == 0:
            values = [(self._default_channel_variables(channel), {}) for channel in self.channels]
        elif len(args) == 1 and isinstance(args[0], (list, tuple)) and len(args[0]) == len(self.channels):
            values = [self._normalize_channel_input(pos, item) for pos, item in enumerate(args[0])]
        else:
            raise ValueError("Pass channel inputs as [(variables, data), ...] or [{'x': x, ...}, ...]")
        checked = []
        for pos, (variables, data_inputs) in enumerate(values):
            # _default_channel_variables may return None; treat it as "no
            # variables" so a missing name is reported as ValueError instead of
            # failing with TypeError on `in None`.
            if variables is None:
                variables = {}
            channel = self.channels[pos]
            for vname in channel.function.variables:
                if vname not in variables:
                    raise ValueError(f'Missing variable {vname} for channel {pos}')
            for dname in channel.function.data:
                if dname not in data_inputs:
                    raise ValueError(f'Missing data input {dname} for channel {pos}')
            checked.append((variables, data_inputs))
        return checked

    def _normalize_channel_input(self, pos, item):
        '''Convert one channel input to a ``(variables, data_inputs)`` pair of dicts.

        Accepted formats:

        * a dict holding variables and data inputs by child name;
        * a 2-item list/tuple ``(variables_dict, data_dict_or_None)``;
        * a flat list/tuple: variables first, then data inputs, in the order
          declared by the channel function.
        '''
        channel = self.channels[pos]
        if item is None:
            raise ValueError(f'Missing inputs for channel {pos}')
        if isinstance(item, dict):
            variables = {key: item[key] for key in channel.function.variables if key in item}
            data_inputs = {key: item[key] for key in channel.function.data if key in item}
            return variables, data_inputs
        if isinstance(item, (list, tuple)) and len(item) == 2 and isinstance(item[0], dict):
            variables = dict(item[0])
            data_inputs = dict(item[1] or {})
            return variables, data_inputs
        if isinstance(item, (list, tuple)):
            nvars = len(channel.function.variables)
            ndata = len(channel.function.data)
            if len(item) != nvars + ndata:
                raise ValueError(f'Channel {pos} expects {nvars} variable inputs and {ndata} data inputs')
            variables = dict(zip(channel.function.variables, item[:nvars]))
            data_inputs = dict(zip(channel.function.data, item[nvars:]))
            return variables, data_inputs
        raise ValueError(f'Invalid input format for channel {pos}')

    def _default_channel_variables(self, channel):
        '''Return ``channel.variables`` inverted (value -> key), or None when it is empty.'''
        if not channel.variables:
            return None
        reverse = {}
        for joint_name, child_name in channel.variables.items():
            reverse[child_name] = joint_name
        return reverse

    def _channel_point_shape(self, channel, variables):
        '''Return the broadcast shape of the channel's variable inputs (``()`` if none).'''
        shapes = [tuple(np.shape(variables[vname])) for vname in channel.function.variables]
        if not shapes:
            return ()
        return np.broadcast_shapes(*shapes)

    def _reshape_model_points(self, data, point_shape, label):
        '''Flatten ``data`` to 2D ``(models, points)``.

        ``data`` must have shape ``point_shape`` (one model) or
        ``(nmodels, *point_shape)``; anything else raises ValueError.
        '''
        if data.shape == point_shape:
            return data.reshape((1, -1))
        if data.shape[1:] == point_shape:
            return data.reshape((data.shape[0], -1))
        raise ValueError(f'{label} shape {data.shape} does not match variables shape {point_shape}')

    def _reshape_data_input(self, data, nmodels, point_shape, label):
        '''Same as ``_reshape_model_points`` but also checks the model count.

        Raises
        ------
        ValueError
            If the leading dimension differs from ``nmodels`` or the shape
            does not match ``point_shape``.
        '''
        if data.shape == point_shape:
            return data.reshape((1, -1))
        if data.shape[1:] == point_shape:
            if data.shape[0] != nmodels:
                raise ValueError(f'{label} has {data.shape[0]} models but expected {nmodels}')
            return data.reshape((data.shape[0], -1))
        raise ValueError(f'{label} shape {data.shape} does not match variables shape {point_shape}')

    def _prepare_weights(self, weights, raw_arrays, channel_inputs, xp):
        '''Expand per-channel weights to the raw array shapes and concatenate them.

        A value that is not a list/tuple is returned untouched. Otherwise one
        entry per channel is required; each is broadcast to the shape of the
        matching array in ``raw_arrays`` and the results are concatenated
        along axis 1.
        '''
        if not isinstance(weights, (list, tuple)):
            return weights
        if len(weights) != len(self.channels):
            raise ValueError('weights must be scalar or one entry per channel')
        arrays = []
        for pos, (weight, raw, inputs) in enumerate(zip(weights, raw_arrays, channel_inputs)):
            variables, _ = inputs
            point_shape = self._channel_point_shape(self.channels[pos], variables)
            weight = xp.asarray(weight)
            if weight.size == 1:
                # Single value: same weight for every model and point.
                arrays.append(xp.full(raw.shape, weight.reshape(())))
                continue
            if weight.shape == point_shape:
                weight = weight.reshape((1, -1))
            elif weight.shape[1:] == point_shape:
                weight = weight.reshape((weight.shape[0], -1))
            else:
                raise ValueError(f'Weight shape {weight.shape} does not match variables shape {point_shape}')
            arrays.append(xp.broadcast_to(weight, raw.shape))
        return xp.concatenate(arrays, axis=1)

    def _get_xp(self, raw_data, channel_inputs, weights):
        '''Return cupy if it is importable and any input is a cupy array, else numpy.'''
        try:
            import cupy as cp
        except ImportError:
            return np
        arrays = list(raw_data)
        for variables, data_inputs in channel_inputs:
            arrays += list(variables.values())
            arrays += list(data_inputs.values())
        if isinstance(weights, (list, tuple)):
            arrays += list(weights)
        else:
            arrays.append(weights)
        return cp if any(isinstance(arr, cp.ndarray) for arr in arrays if arr is not None) else np

    # ---------------------------------------------------------------------
    # Kernel construction
    # ---------------------------------------------------------------------

    def _build_kernels(self):
        '''Generate and jit the scalar, array and jacobian kernels (CPU and GPU).

        Raises
        ------
        ValueError
            If two kernel inputs share the same name.
        '''
        names = self.variables + self.data + self._parameters_names + self._constant_names
        self._check_unique_input_names(names)
        inputs = ', '.join(names)
        self.cpukernel_function = self._build_scalar_kernel(inputs, False, 'function')
        self.gpukernel_function = self._build_scalar_kernel(inputs, True, 'function')
        for pname in self._parameters_names:
            setattr(self, f'cpukernel_d_{pname}', self._build_scalar_kernel(inputs, False, pname))
            setattr(self, f'gpukernel_d_{pname}', self._build_scalar_kernel(inputs, True, pname))
        self.cpu_function = self._build_array_kernel(inputs, False)
        self.gpu_function = self._build_array_kernel(inputs, True)
        self.cpu_jacobian = self._build_jacobian_kernel(inputs, False)
        self.gpu_jacobian = self._build_jacobian_kernel(inputs, True)

    @staticmethod
    def _check_unique_input_names(names):
        '''Raise ValueError when ``names`` holds duplicates.

        The names become the arguments of a generated ``def``; a duplicate
        would otherwise only surface as an opaque SyntaxError from ``exec``.
        '''
        seen = set()
        duplicates = []
        for name in names:
            if name in seen and name not in duplicates:
                duplicates.append(name)
            seen.add(name)
        if duplicates:
            raise ValueError(f'Duplicate joint input names: {", ".join(duplicates)}')

    @staticmethod
    def _exec_kernel_source(source, glob):
        '''Execute generated source defining ``func`` and return that function.'''
        # NOTE(review): the source embeds channel prefixes and variable /
        # parameter / data / constant names. Prefixes and joint names are
        # checked with str.isidentifier; child names come from the channel
        # functions and are trusted here.
        loc = {}
        exec(source, glob, loc)
        return loc['func']

    def _build_scalar_kernel(self, inputs, cuda, derivative):
        '''Build the per-point kernel for the function or one parameter derivative.

        The generated function dispatches on ``channel`` and returns 0.0 when
        no channel index matches.
        '''
        body = self._scalar_body(derivative)
        source = f'def func({inputs}):\n{body}\n    return 0.0\n'
        func = self._exec_kernel_source(source, self._kernel_globals(cuda))
        if cuda:
            return nb.cuda.jit(device=True)(func)
        return nb.njit(nogil=True)(func)

    def _build_array_kernel(self, inputs, cuda):
        '''Build the kernel filling ``out[model, point]``; ignored models are skipped.'''
        call = self._indexed_call()
        if cuda:
            # One thread per (model, point) pair.
            source = '\n'.join([
                '',
                f'def func({inputs}, out, ignore):',
                '    nmodels, npoints = out.shape',
                '    model, point = nb.cuda.grid(2)',
                '    if model < nmodels and point < npoints and not ignore[model]:',
                f'        out[model, point] = scalar({call})',
                '',
            ])
            glob = {'nb': nb, 'scalar': self.gpukernel_function}
            return nb.cuda.jit()(self._exec_kernel_source(source, glob))
        source = '\n'.join([
            '',
            f'def func({inputs}, out, ignore):',
            '    nmodels, npoints = out.shape',
            '    for model in nb.prange(nmodels):',
            '        if ignore[model]:',
            '            continue',
            '        for point in range(npoints):',
            f'            out[model, point] = scalar({call})',
            '',
        ])
        glob = {'nb': nb, 'scalar': self.cpukernel_function}
        return nb.njit(parallel=True, nogil=True)(self._exec_kernel_source(source, glob))

    def _build_jacobian_kernel(self, inputs, cuda):
        '''Build the kernel filling ``jacobian[model, point, k]``.

        ``k`` counts only the parameters whose ``bool2fit`` entry is true, in
        joint parameter order. Ignored models are skipped.
        '''
        call = self._indexed_call()
        target = 'gpukernel' if cuda else 'cpukernel'
        glob = {'nb': nb}
        for key in self._parameters_names:
            glob[f'd_{key}'] = getattr(self, f'{target}_d_{key}')
        if cuda:
            source = '\n'.join([
                '',
                f'def func({inputs}, jacobian, bool2fit, ignore):',
                '    nmodels, npoints, nparams = jacobian.shape',
                '    model, point = nb.cuda.grid(2)',
                '    if model < nmodels and point < npoints and not ignore[model]:',
                '        count = 0',
                self._jacobian_assignments(call, 8),
                '',
            ])
            return nb.cuda.jit()(self._exec_kernel_source(source, glob))
        source = '\n'.join([
            '',
            f'def func({inputs}, jacobian, bool2fit, ignore):',
            '    nmodels, npoints, nparams = jacobian.shape',
            '    for model in nb.prange(nmodels):',
            '        if ignore[model]:',
            '            continue',
            '        for point in range(npoints):',
            '            count = 0',
            self._jacobian_assignments(call, 12),
            '',
        ])
        return nb.njit(parallel=True, nogil=True)(self._exec_kernel_source(source, glob))

    def _jacobian_assignments(self, call, indent):
        '''Return the source lines writing one jacobian column per fitted parameter.'''
        pad = ' ' * indent
        lines = []
        for pos, key in enumerate(self._parameters_names):
            lines.append(f'{pad}if bool2fit[{pos}]:')
            lines.append(f'{pad}    jacobian[model, point, count] = d_{key}({call})')
            lines.append(f'{pad}    count += 1')
        return '\n'.join(lines)

    def _indexed_call(self):
        '''Return the argument list used by array kernels to call a scalar kernel.

        Variables are indexed by point, data by (model, point), parameters by
        model; constants are passed as they are.
        '''
        variables = [f'{key}[point]' for key in self.variables]
        data = [f'{key}[model, point]' for key in self.data]
        parameters = [f'{key}[model]' for key in self._parameters_names]
        return ', '.join(variables + data + parameters + self._constant_names)

    def _scalar_body(self, derivative):
        '''Return the ``if / elif channel == i`` dispatch body of a scalar kernel.'''
        body = ''
        for pos, channel in enumerate(self.channels):
            prefix = 'if' if pos == 0 else 'elif'
            body += f'    {prefix} channel == {pos}:\n'
            expr = self._channel_expression(pos, derivative)
            body += f'        return {expr}\n'
        return body

    def _channel_expression(self, pos, derivative):
        '''Return the expression evaluated for channel ``pos``.

        ``derivative`` is ``'function'`` for the model value, otherwise the
        joint parameter name to differentiate against. Parameters that do not
        belong to the channel yield ``'0.0'``.
        '''
        if derivative == 'function':
            return f'f{pos}({self._child_inputs(pos)})'
        # Shared joint parameters: derivative is the child derivative for each
        # channel where the parameter is linked, otherwise zero.
        if derivative in self.shared_parameters:
            child_name = self.shared_parameters[derivative][pos]
            if child_name is None:
                return '0.0'
            return f'd{pos}_{child_name}({self._child_inputs(pos)})'
        # Channel-local parameters.
        for child_name, joint_name in self._channel_parameter_names[pos].items():
            if derivative == joint_name:
                return f'd{pos}_{child_name}({self._child_inputs(pos)})'
        return '0.0'

    def _child_inputs(self, pos):
        '''Return the child call arguments: variables, data, parameters, constants.'''
        channel = self.channels[pos]
        values = []
        values += [self._channel_variable_exprs[pos][key] for key in channel.function.variables]
        values += [self._channel_data_names[pos][key] for key in channel.function.data]
        for pname in channel.function.parameters.keys():
            if pname in self._channel_shared_parameter_names[pos]:
                values.append(self._channel_shared_parameter_names[pos][pname])
            else:
                values.append(self._channel_parameter_names[pos][pname])
        values += [self._channel_constant_names[pos][key] for key in channel.function.constants.keys()]
        return ', '.join(values)

    def _mapped_variable_expression(self, pos, child_name):
        '''Return the joint-variable expression feeding ``child_name`` in channel ``pos``.

        Raises
        ------
        ValueError
            If the child variable is not mapped by ``shared_variables``.
        '''
        joint_name = self._joint_variable_for_child(pos, child_name)
        return self._affine_variable_expression(pos, child_name, joint_name)

    def _affine_variable_expression(self, pos, child_name, fallback_joint_name):
        '''Return the affine expression for ``child_name``, or the fallback name.

        For the first affine group containing ``child_name``, the matching
        matrix row is expanded as ``coeff * joint_variable`` terms (zero
        coefficients skipped) plus the last-column shift when non-zero.
        '''
        channel = self.channels[pos]
        for group, matrix in channel.affine.items():
            if child_name not in group:
                continue
            row = group.index(child_name)
            terms = []
            for col, other_child_name in enumerate(group):
                joint_name = self._joint_variable_for_child(pos, other_child_name)
                coeff = float(matrix[row, col])
                if coeff == 0:
                    continue
                terms.append(f'({coeff!r} * {joint_name})')
            shift = float(matrix[row, len(group)])
            if shift != 0:
                terms.append(f'{shift!r}')
            return '(' + ' + '.join(terms or ['0.0']) + ')'
        return fallback_joint_name

    def _joint_variable_for_child(self, pos, child_name):
        '''Return the first joint variable mapped to ``child_name`` in channel ``pos``.'''
        for joint_name, mapping in self.shared_variables.items():
            if mapping[pos] == child_name:
                return joint_name
        raise ValueError(f'Channel {pos} variable {child_name} is not mapped by shared_variables')

    def _kernel_globals(self, cuda):
        '''Return the globals for scalar kernels: ``f<pos>`` and ``d<pos>_<param>`` children.'''
        glob = {'nb': nb}
        for pos, channel in enumerate(self.channels):
            glob[f'f{pos}'] = channel.function.gpukernel_function if cuda else channel.function.cpukernel_function
            for pname in channel.function.parameters.keys():
                glob[f'd{pos}_{pname}'] = getattr(channel.function, f'gpukernel_d_{pname}' if cuda else f'cpukernel_d_{pname}')
        return glob


def _as_list(value, n, name):
    '''Return a per-channel mapping as a fresh list of length ``n``.

    ``None`` expands to ``n`` times ``None``. Anything that is not a
    list/tuple of exactly ``n`` items raises ValueError; ``name`` is only
    used in the error message.
    '''
    if value is None:
        return [None] * n
    well_formed = isinstance(value, (list, tuple)) and len(value) == n
    if not well_formed:
        raise ValueError(f'{name} must be a list/tuple of length {n}')
    return [*value]


def _unique_append(values, value):
    '''Append ``value`` to ``values`` in place unless an equal item is already there.'''
    if value in values:
        return
    values.append(value)


class _FitUfunc:
    '''Plain holder describing the flat call signature of a joint function.

    Variables and data become positional-only entries of ``signature``;
    parameters become positional-or-keyword entries whose default is
    ``parameter_specs[name].default``.
    '''

    def __init__(self, variables, data, parameters, parameter_specs):
        self.variables = variables
        self.data = data
        self.parameters = parameters
        self.parameter_specs = parameter_specs
        positional_only = inspect.Parameter.POSITIONAL_ONLY
        keyword_ok = inspect.Parameter.POSITIONAL_OR_KEYWORD
        # Inputs first (variables, then data), fitted parameters last.
        entries = [inspect.Parameter(key, positional_only) for key in [*variables, *data]]
        entries.extend(
            inspect.Parameter(key, keyword_ok, default=parameter_specs[key].default)
            for key in parameters
        )
        self.signature = inspect.Signature(entries)


def convert(value):
    '''Cast a parameter value to bool / int32 / float32.

    Objects exposing ``dtype`` are cast with ``astype`` according to their
    numpy dtype family. If an AttributeError occurs on that path, the value
    is handled as a scalar by type instead. Any other kind of value raises
    TypeError.
    '''
    try:
        dtype = value.dtype
        # Checked in this order: boolean, integer, floating.
        for family, target in ((np.bool_, np.bool_), (np.integer, np.int32), (np.floating, np.float32)):
            if np.issubdtype(dtype, family):
                return value.astype(target)
        raise TypeError(f'Parameter cannot have {dtype} dtype')
    except AttributeError:
        # bool is tested before int because bool is a subclass of int.
        scalar_casts = (
            ((bool, np.bool_), np.bool_),
            ((int, np.integer), np.int32),
            ((float, np.floating), np.float32),
        )
        for kinds, target in scalar_casts:
            if isinstance(value, kinds):
                return target(value)
        raise TypeError(f'Parameter cannot have {type(value)} dtype')


# %% Test function run
if __name__ == "__main__":
    # Only when this file is executed as a script: import corelp's `test`
    # helper lazily and call it with this file's path.
    from corelp import test
    test(__file__)