mlp
Multi-layer perceptron discriminator for permutation weighting.
MLPDiscriminator(hidden_dims=None, activation='relu')
Bases: BaseDiscriminator
Multi-layer perceptron (MLP) discriminator using treatments A, covariates X, and their interactions A*X.
The MLP passes the concatenated features [A, X, A*X] through configurable hidden layers, applying the chosen activation function between layers, and outputs a scalar logit.
This gives it more expressive power than a linear discriminator, so it can capture nonlinear relationships between treatments and covariates.
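The forward pass described above can be sketched in NumPy as a stand-in for `jax.numpy`. This is an illustration, not the library's code: the helper `mlp_logits`, the layer layout (`'w'` of shape `(in_dim, out_dim)`, `'b'` of shape `(out_dim,)`) and the choice of no activation on the output layer are assumptions.

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def mlp_logits(layers, a, x, ax, activation=relu):
    """Hypothetical MLP discriminator forward pass.

    layers: list of dicts with 'w' (in_dim, out_dim) and 'b' (out_dim,) (assumed layout).
    a: (batch, d_a) treatments, x: (batch, d_x) covariates, ax: (batch, d_a * d_x).
    """
    h = np.concatenate([a, x, ax], axis=1)        # features [A, X, A*X]
    for layer in layers[:-1]:
        h = activation(h @ layer["w"] + layer["b"])  # hidden layers with activation
    out = h @ layers[-1]["w"] + layers[-1]["b"]      # linear output layer (assumed)
    return out[:, 0]                                  # one scalar logit per row

rng = np.random.default_rng(0)
d_a, d_x, batch = 1, 3, 5
a = rng.normal(size=(batch, d_a))
x = rng.normal(size=(batch, d_x))
ax = (a[:, :, None] * x[:, None, :]).reshape(batch, -1)  # row-wise a_i * x_j

dims = [d_a + d_x + d_a * d_x, 64, 32, 1]  # default hidden sizes [64, 32]
layers = [{"w": rng.normal(size=(m, n)) * 0.1, "b": np.zeros(n)}
          for m, n in zip(dims[:-1], dims[1:])]
logits = mlp_logits(layers, a, x, ax)
print(logits.shape)  # (5,)
```

The input width is `d_a + d_x + d_a * d_x` because the three feature groups are concatenated column-wise before the first layer.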
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_dims` | `list[int]` | List of hidden layer sizes. If `None`, `[64, 32]` is used. | `None` |
| `activation` | `str`, one of `'relu'`, `'tanh'`, `'elu'`, `'sigmoid'` | Activation function applied between layers | `'relu'` |
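One way to resolve the `activation` string is a small lookup table. The sketch below is hypothetical: the `ACTIVATIONS` dict, `get_activation`, ELU with `alpha = 1`, and raising `ValueError` on unknown names are assumptions, not documented behavior of `MLPDiscriminator`.

```python
import numpy as np

# Hypothetical name-to-function table covering the four documented choices.
ACTIVATIONS = {
    "relu": lambda z: np.maximum(z, 0.0),
    "tanh": np.tanh,
    "elu": lambda z: np.where(z > 0, z, np.expm1(z)),  # ELU with alpha = 1 (assumed)
    "sigmoid": lambda z: 1.0 / (1.0 + np.exp(-z)),
}

def get_activation(name):
    """Return the activation function for `name`, or raise ValueError (assumed behavior)."""
    try:
        return ACTIVATIONS[name]
    except KeyError:
        raise ValueError(
            f"unknown activation {name!r}; expected one of {sorted(ACTIVATIONS)}"
        ) from None

z = np.array([-1.0, 0.0, 2.0])
print(get_activation("relu")(z))  # [0. 0. 2.]
```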
Examples:
```python
>>> from stochpw.models import MLPDiscriminator
>>> import jax
>>>
>>> # Default: two hidden layers of sizes [64, 32] with ReLU
>>> discriminator = MLPDiscriminator()
>>>
>>> # Custom: three hidden layers with tanh
>>> discriminator = MLPDiscriminator(hidden_dims=[128, 64, 32], activation="tanh")
>>>
>>> params = discriminator.init_params(jax.random.PRNGKey(0), d_a=1, d_x=3)
```
Source code in src/stochpw/models/mlp.py
apply(params, a, x, ax)
Compute MLP discriminator logits using A, X, and A*X.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params` | `dict` | Parameters with key `'layers'` containing a list of layer dicts | required |
| `a` | `Array`, shape `(batch_size, d_a)` or `(batch_size,)` | Treatment assignments | required |
| `x` | `Array`, shape `(batch_size, d_x)` | Covariates | required |
| `ax` | `Array`, shape `(batch_size, d_a * d_x)` | Precomputed first-order interactions A ⊗ X | required |
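Since `ax` must be precomputed, here is one way to build it from `a` and `x`, sketched in NumPy. The helper `interactions` is hypothetical; promoting a 1-D `a` to a column and the flattening order (entry `i * d_x + j` holds `a_i * x_j`) are assumptions, so check the library's own interaction helper, if any, before relying on this ordering.

```python
import numpy as np

def interactions(a, x):
    """Hypothetical helper: row-wise A ⊗ X flattened to shape (batch_size, d_a * d_x)."""
    a = np.asarray(a, dtype=float)
    if a.ndim == 1:          # (batch_size,) -> (batch_size, 1)
        a = a[:, None]
    x = np.asarray(x, dtype=float)
    # Outer product per row: (batch, d_a, d_x), then flatten the last two axes
    return np.einsum("bi,bj->bij", a, x).reshape(a.shape[0], -1)

a = np.array([1.0, 0.0])           # binary treatment, shape (2,)
x = np.array([[2.0, 3.0, 4.0],
              [5.0, 6.0, 7.0]])    # covariates, shape (2, 3)
print(interactions(a, x))
# [[2. 3. 4.]
#  [0. 0. 0.]]
```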
Returns:
| Name | Type | Description |
|---|---|---|
| `logits` | `Array`, shape `(batch_size,)` | Discriminator logits for p(C=1 \| a, x) |
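Because the output is a logit, p(C=1 | a, x) is its sigmoid, and the odds p(C=1 | a, x) / p(C=0 | a, x) equal `exp(logit)`. A numerically stable conversion, sketched in NumPy, follows; the helper name `logits_to_prob` is hypothetical, and how the library turns these quantities into permutation weights is not shown here.

```python
import numpy as np

def logits_to_prob(logits):
    """p(C=1 | a, x) = sigmoid(logit), computed without overflow for large |logit|."""
    logits = np.asarray(logits, dtype=float)
    e = np.exp(-np.abs(logits))  # always in (0, 1], never overflows
    # Two algebraically equivalent forms of the sigmoid, chosen by sign
    return np.where(logits >= 0, 1.0 / (1.0 + e), e / (1.0 + e))

logits = np.array([-1000.0, 0.0, 2.0])
p = logits_to_prob(logits)
odds = np.exp(logits[1:])  # p / (1 - p) for the finite-odds entries
```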
init_params(rng_key, d_a, d_x)
Initialize MLP discriminator parameters.
Uses He initialization for ReLU-family activations and Xavier initialization for tanh/sigmoid activations. Biases are initialized to zero.
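The initialization rule above can be sketched per layer as follows. This is an assumption-laden illustration: it uses the *normal* variants (He std `sqrt(2 / fan_in)`, Xavier/Glorot std `sqrt(2 / (fan_in + fan_out))`), treats ELU as ReLU-family, and the helper `init_layer` is hypothetical. The library may use uniform or truncated variants instead.

```python
import numpy as np

def init_layer(rng, fan_in, fan_out, activation):
    """Hypothetical per-layer init: He for ReLU-family, Xavier for tanh/sigmoid, zero bias."""
    if activation in ("relu", "elu"):
        std = np.sqrt(2.0 / fan_in)              # He (Kaiming) normal
    elif activation in ("tanh", "sigmoid"):
        std = np.sqrt(2.0 / (fan_in + fan_out))  # Xavier (Glorot) normal
    else:
        raise ValueError(f"unknown activation {activation!r}")
    return {"w": rng.normal(0.0, std, size=(fan_in, fan_out)),
            "b": np.zeros(fan_out)}

rng = np.random.default_rng(0)
layer = init_layer(rng, 7, 64, "relu")
print(layer["w"].shape, layer["b"].shape)  # (7, 64) (64,)
```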
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `rng_key` | `Array` | Random key for parameter initialization | required |
| `d_a` | `int` | Dimension of treatment vector | required |
| `d_x` | `int` | Dimension of covariate vector | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `params` | `dict` | Dictionary with key `'layers'` containing a list of layer dicts, each with keys `'w'` (weight matrix) and `'b'` (bias vector) |
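The shapes of the returned layer dicts follow from the [A, X, A*X] input width and `hidden_dims`. The hypothetical helper below lists the expected `(w, b)` shapes, assuming weights are stored as `(in_dim, out_dim)` and the output layer maps to a single unit with a `(1,)` bias. The library may store the last layer differently.

```python
def layer_shapes(d_a, d_x, hidden_dims=None):
    """Hypothetical: expected (w, b) shapes from [A, X, A*X] through hidden layers to one logit."""
    if hidden_dims is None:
        hidden_dims = [64, 32]              # documented default
    d_in = d_a + d_x + d_a * d_x          # width of the concatenated features
    dims = [d_in, *hidden_dims, 1]
    return [((m, n), (n,)) for m, n in zip(dims[:-1], dims[1:])]

print(layer_shapes(d_a=1, d_x=3))
# [((7, 64), (64,)), ((64, 32), (32,)), ((32, 1), (1,))]
```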