Chapter 17

The Two-Task Closed Form

Inverse-Gradient Balancing: The Idea

One Formula, Three Doors

On a TCP-congested network link, two flows compete for one pipe. Telecom engineers in the 1990s argued for years over the ‘right’ way to share bandwidth. Three camps emerged:

  • Equality — give each flow the same throughput.
  • Max-min — lift the slowest flow as high as possible.
  • Minimum variance — minimise the spread of throughputs.

For two flows on a single link, all three philosophies produce the same allocation. Different doors, same room. The K=2 GABA closed form is the gradient-space twin of that allocation: three apparently different objectives all collapse onto the same $\lambda^*$.

Why it matters. If the closed form were only the answer to one objective, you might worry it is an accident of formulation. The fact that $\lambda^*$ is simultaneously the optimum of three differently motivated objectives is evidence that the formula is structural, not arbitrary.

The Closed Form

For two tasks $\{\text{rul}, \text{health}\}$ with shared-backbone gradient norms $g_{\text{rul}}, g_{\text{health}} \geq 0$ (and not both zero), GABA assigns:

$$\lambda^*_{\text{rul}} = \frac{g_{\text{health}}}{g_{\text{rul}} + g_{\text{health}}}, \quad \lambda^*_{\text{health}} = \frac{g_{\text{rul}}}{g_{\text{rul}} + g_{\text{health}}}$$

The numerators are swapped: that is the inverse-proportional structure. The denominators coincide, guaranteeing $\lambda^*_{\text{rul}} + \lambda^*_{\text{health}} = 1$. Plugging in the realistic numbers from §12.3, $g_{\text{rul}} = 5.0$ and $g_{\text{health}} = 0.01$, yields $\lambda^*_{\text{rul}} = 0.001996$ and $\lambda^*_{\text{health}} = 0.998004$.
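As a quick check of the formula, the sketch below wraps it in a small helper and confirms the two properties stated above: the weights sum to 1 and the weighted contributions come out equal. The function name gaba_two_task_weights is a hypothetical choice for illustration and is not defined anywhere in the source.

```python
def gaba_two_task_weights(g_rul: float, g_health: float) -> tuple[float, float]:
    """Return (lambda_rul, lambda_health) from the K=2 closed form.

    Hypothetical helper for illustration. Assumes non-negative gradient
    norms that are not both zero, as stated in the text.
    """
    if g_rul < 0 or g_health < 0:
        raise ValueError("gradient norms must be non-negative")
    total = g_rul + g_health
    if total == 0:
        raise ValueError("at least one gradient norm must be positive")
    # Swapped numerators: each task is weighted by the OTHER task's norm.
    return g_health / total, g_rul / total


lam_rul, lam_health = gaba_two_task_weights(5.0, 0.01)
print(f"lambda_rul    = {lam_rul:.6f}")
print(f"lambda_health = {lam_health:.6f}")
print(f"sum           = {lam_rul + lam_health:.6f}")
print(f"contributions = {lam_rul * 5.0:.8f}, {lam_health * 0.01:.8f}")
```

The two printed contributions should agree, which is the equal-contribution property the derivations below start from.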

Derivation 1: Lagrangian on the Equality Constraint

Set up the constrained problem. Variables $\lambda_1, \lambda_2 \geq 0$ on the simplex $\lambda_1 + \lambda_2 = 1$ with the equal-contribution requirement $\lambda_1 g_1 = \lambda_2 g_2$. Form the Lagrangian:

$$\mathcal{L}(\lambda_1, \lambda_2, \mu, \nu) = \tfrac{1}{2}(\lambda_1 g_1 - \lambda_2 g_2)^2 + \mu\,(\lambda_1 + \lambda_2 - 1) + \nu\,(\lambda_1 g_1 - \lambda_2 g_2)$$

Stationarity with respect to the multipliers $\mu$ and $\nu$ recovers the two constraints, in particular $\lambda_1 g_1 = \lambda_2 g_2$. Combine with $\lambda_1 + \lambda_2 = 1$:

$$\lambda_1 g_1 = (1 - \lambda_1) g_2 \;\Longrightarrow\; \lambda_1 (g_1 + g_2) = g_2 \;\Longrightarrow\; \lambda_1 = \frac{g_2}{g_1 + g_2}$$

Symmetric for $\lambda_2$. The KKT non-negativity multipliers vanish because the solution obtained without the non-negativity constraints already satisfies $\lambda_i \geq 0$ whenever $g_i \geq 0$. The closed form is the unique KKT point.
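The two constraints form a 2×2 linear system, so the derivation can be cross-checked numerically. The sketch below solves that system with numpy.linalg.solve and compares the result with the closed form; the gradient values are the same illustrative pair used elsewhere in this section, and the variable names are my own.

```python
import numpy as np

g1, g2 = 5.0, 0.01  # illustrative gradient norms (task 1 = rul, task 2 = health)

# Row 1: lambda_1 + lambda_2 = 1            (simplex constraint)
# Row 2: g1 * lambda_1 - g2 * lambda_2 = 0  (equal-contribution constraint)
A = np.array([[1.0, 1.0],
              [g1, -g2]])
b = np.array([1.0, 0.0])

lam_solved = np.linalg.solve(A, b)
lam_closed = np.array([g2, g1]) / (g1 + g2)

print("solved:", lam_solved)
print("closed:", lam_closed)
```

The determinant of the system is $-(g_1 + g_2)$, which is nonzero whenever the norms are not both zero; that is the linear-algebra reason the solution is unique.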

Derivation 2: Max-Min Fairness

Different setup, same answer. Forget the equality constraint — pose the LP:

$$\max_{\lambda_1, \lambda_2 \geq 0,\, \lambda_1 + \lambda_2 = 1} \quad \min(\lambda_1 g_1,\, \lambda_2 g_2)$$

The objective is a piecewise-linear concave function of $\lambda_1$:

  • For small $\lambda_1$: $\min = \lambda_1 g_1$ (increasing).
  • For large $\lambda_1$: $\min = (1 - \lambda_1) g_2$ (decreasing).

The maximum sits at the kink where the two pieces cross, $\lambda_1 g_1 = (1 - \lambda_1) g_2$: the same equation as before, with the same solution.

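The same equation appears for a third reason. Writing $c_i = \lambda_i g_i$ for the two contributions, minimising $\mathrm{Var}(c_1, c_2) = \tfrac{1}{4}(c_1 - c_2)^2$ is minimising a positive constant multiple of Derivation 1's squared-gap objective $\tfrac{1}{2}(c_1 - c_2)^2$ (exactly one half of it, or one quarter of the bare squared gap). Same argmin. Three doors, one room.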
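A brute-force grid evaluation makes the coincidence concrete. The sketch below, a rough check rather than a proof, evaluates the max-min objective and the variance on a fine grid of $\lambda_1$ values, locates each optimum, and also verifies the identity $\mathrm{Var}(c_1, c_2) = \tfrac{1}{4}(c_1 - c_2)^2$ pointwise. The grid resolution is an arbitrary choice.

```python
import numpy as np

g1, g2 = 5.0, 0.01                      # illustrative gradient norms
lam = np.linspace(0.0, 1.0, 200_001)    # grid over lambda_1, step 5e-6
step = lam[1] - lam[0]

c1 = lam * g1                           # contribution of task 1
c2 = (1.0 - lam) * g2                   # contribution of task 2

maxmin = np.minimum(c1, c2)             # objective to MAXIMISE
var = np.var(np.stack([c1, c2]), axis=0)  # population variance, to MINIMISE

lam_maxmin = lam[np.argmax(maxmin)]
lam_var = lam[np.argmin(var)]
closed = g2 / (g1 + g2)

print(f"closed form     : {closed:.6f}")
print(f"grid max-min    : {lam_maxmin:.6f}")
print(f"grid min-var    : {lam_var:.6f}")
print("variance identity holds:", np.allclose(var, 0.25 * (c1 - c2) ** 2))
```

Both grid optima can only be as accurate as the grid spacing, so they should land within one step of the closed form rather than match it exactly.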

Interactive: Three Objectives, One Optimum

The plot below evaluates the three objectives along $\lambda_{\text{rul}} \in [0, 1]$, normalises each to $[0, 1]$, and overlays them. The amber dashed line marks $\lambda^* = g_{\text{health}} / (g_{\text{rul}} + g_{\text{health}})$. All three curves peak there, regardless of the gradient ratio.

Try this. Slide $g_{\text{rul}}$ from 0.001 up to 100. The amber line walks across the plot, but all three coloured curves continue to peak together. Now move the perturbation slider to $+10\%$: the red dashed line shifts slightly to the left, reporting the relative change in $\lambda^*$. Notice that a 10% gradient measurement error produces a less-than-10% shift in $\lambda^*$: the closed form is stable.
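The stability claim can be reproduced without the slider. The text does not say which gradient norm the slider perturbs, so this sketch (with a hypothetical helper lam_star) applies a +10% error to each norm in turn and reports the relative change in $\lambda^*$.

```python
def lam_star(g_rul: float, g_health: float) -> float:
    """K=2 closed form for lambda_rul (hypothetical helper name)."""
    return g_health / (g_rul + g_health)


g_rul, g_health = 5.0, 0.01
base = lam_star(g_rul, g_health)

# Relative change in lambda* when one norm is over-measured by 10%.
rel_shift_rul = lam_star(1.10 * g_rul, g_health) / base - 1.0
rel_shift_health = lam_star(g_rul, 1.10 * g_health) / base - 1.0

print(f"+10% on g_rul    -> relative shift in lam* = {rel_shift_rul:+.4%}")
print(f"+10% on g_health -> relative shift in lam* = {rel_shift_health:+.4%}")
```

Both shifts stay below 10% in magnitude because the elasticity of $\lambda^*$ with respect to either norm has magnitude $g_{\text{rul}} / (g_{\text{rul}} + g_{\text{health}})$, which is always less than 1.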

Python: Numerical Optimisers Recover the Closed Form

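Write the three objectives as Python functions, hand them to a generic 1-D bounded minimiser, and confirm they all return the same $\lambda^*$ to ten decimal places. The point is to corroborate the analytic derivations with numerical evidence. One caveat: SciPy's bounded minimiser stops at an absolute tolerance on x (xatol, default 1e-5), so agreement beyond roughly five decimals depends on the objective's shape and may need a tighter options={"xatol": ...}.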

Three objectives, one numerical optimum
🐍closed_form_objectives.py
1Module docstring

States the central claim of this file: three apparently unrelated optimisation objectives — squared gap, max-min fairness, and minimum variance — all collapse onto the SAME closed-form value of lambda. The rest of the file proves this numerically by handing each objective to SciPy and checking that all three return the same answer to 10 decimals.

EXECUTION STATE
→ why a triple-check? = Mathematically, equal-contribution, max-min, and min-variance for K=2 share a unique optimum. We're corroborating the algebraic proofs (Derivations 1 and 2 in the prose above) with numerical evidence. If any of the three SciPy runs disagreed, either the algebra is wrong or the optimiser has not converged; both are worth knowing.
3import numpy as np

NumPy is Python's numerical computing workhorse. It provides ndarray (an N-dimensional, dense, contiguous float array) along with vectorised math and reductions implemented in C. We need only two pieces of NumPy in this file: np.array() to build the small contribution vector inside variance_contribution(), and the .var() reduction to compute its population variance.

EXECUTION STATE
📚 numpy = Numerical computing library. Provides ndarray (the fast typed array), broadcasting, linear algebra, random number generation, and reductions like .sum(), .mean(), .var(). All the heavy lifting runs in C, not Python loops.
as np = Universal community alias. Lets us write np.array() rather than numpy.array(). Same convention as 'import pandas as pd' or 'import matplotlib.pyplot as plt'.
→ why import at all here? = Method 3 (variance_contribution) is the only function that actually needs NumPy — methods 0, 1, 2 work on plain Python floats. We import once at the top so all three objectives can be defined in the same file.
4from scipy.optimize import minimize_scalar

SciPy's 1-D scalar minimiser. The 'from … import name' form pulls only the one function we need into the namespace, so we write minimize_scalar(...) rather than scipy.optimize.minimize_scalar(...). All three of our objectives have the same shape (R → R) so this single function handles all of them.

EXECUTION STATE
📚 scipy.optimize.minimize_scalar(fun, bounds, method) = 1-D scalar minimiser. With method='bounded' it uses a bounded variant of Brent's method (parabolic interpolation + golden-section fallback) restricted to [a, b]. Returns an OptimizeResult; the argmin is in the .x attribute.
→ why scalar (not minimize)? = scipy.optimize.minimize targets vector-valued x, and its default methods rely on gradients (supplied or approximated by finite differences). minimize_scalar is purpose-built for 1-D and needs no derivatives. Our problem is 1-D after the simplex constraint is folded in, so minimize_scalar is the right tool.
→ why bounded? = Our search interval is [0, 1]. method='bounded' guarantees the optimiser never tries lambda < 0 or > 1, where our objectives have no physical meaning. method='brent' would happily wander outside.
7# Realistic FD002 numbers from §12.3

Comment that anchors the demo to the paper's measured numbers rather than contrived toy values. FD002 is the second sub-dataset of the NASA C-MAPSS turbofan benchmark; §12.3 of this book reports the median per-parameter gradient norms on the shared backbone for the RUL regression head and the health-classification head.

EXECUTION STATE
→ why anchor to real numbers? = Toy values like (g_rul, g_health) = (1, 1) would make all three methods trivially equal at lam=0.5 and prove nothing. The whole point of the closed form is its behaviour when gradients are wildly imbalanced (here 500x). Using the real ratio surfaces the algorithm's most interesting regime.
8g_rul, g_health = 5.0, 0.01

Tuple-unpacking assignment — Python evaluates the right side as the tuple (5.0, 0.01), then binds g_rul to 5.0 and g_health to 0.01 in one statement. These are the two gradient-norm scalars that drive the entire experiment. The huge ratio (500x) is exactly the imbalance GABA was designed to fix.

EXECUTION STATE
g_rul = Float = 5.0. Median ‖∂L_RUL/∂θ_backbone‖ across the paper's n=4,120 backbone-parameter sample. RUL is a regression task with continuous targets ∈ [0, 125] cycles, so its gradients are unbounded.
g_health = Float = 0.01. Median ‖∂L_health/∂θ_backbone‖. Health is a 6-class softmax cross-entropy; its gradient with respect to the logits is p − y, whose norm never exceeds √2 (comfortably inside the looser sqrt(K) ≈ 2.45 bound), and at convergence it is much smaller.
→ tuple unpacking = Equivalent to: tmp = (5.0, 0.01); g_rul = tmp[0]; g_health = tmp[1]. Python supports it natively for any iterable RHS with matching length.
→ ratio g_rul / g_health = 5.0 / 0.01 = 500. The RUL gradient outweighs the health gradient 500-to-1. Without rebalancing, an SGD step on the joint loss is ≈99.8% RUL update and ≈0.2% health update — health learns essentially nothing.
9S = g_rul + g_health

Pre-compute the sum once. S appears in the denominator of the closed form (line 13) and of both partial derivatives (lines 47 and 48). Caching it avoids recomputing the same addition at each of those sites.

EXECUTION STATE
S = 5.0 + 0.01 = 5.01. The 'mass' of the gradient pair: the total per-parameter-norm budget being divided between the two tasks.
→ micro-optimisation note = S is used only three times, all outside the objective functions, so the saving is negligible here. The habit of hoisting repeated computations is still good practice.
12# Method 0: closed form (analytic)

Section header. Method 0 is the algebraic ground truth derived in §17.2; methods 1, 2, 3 below are numerical alternatives that should — if the algebra is correct — converge to the same value.

13lambda_closed = g_health / S

The K=2 closed form derived two ways in the prose above: lambda_rul* = g_health / (g_rul + g_health). One division. No iteration. No tuning. This is the value every numerical method on the page must reproduce.

EXECUTION STATE
g_health / S = 0.01 / 5.01. Division of two Python floats — IEEE-754 64-bit result.
→ walk through arithmetic = Numerator: 0.01 Denominator: 5.01 Quotient: 0.01 / 5.01 = 0.0019960079840319362
lambda_closed = 0.0019960079840319362 (printed as 0.0019960080 to 10 decimals). The TARGET every numerical method must hit.
→ intuition = lambda_rul is TINY (~0.2%) because RUL's gradient is already 500x larger than health's. To equalise the effective contributions (lam·g_rul = (1-lam)·g_health), the RUL weight must shrink the dominant gradient down to the tiny gradient's level, hence the down-weight by ~500x.
16# Method 1: minimise the squared gap between contributions

Section header. Sets up the first numerical objective: J1(lam) = (c_rul - c_health)^2. This equals zero IFF the two contributions are exactly equal, so the global minimum sits at the equal-contribution point — which is the closed form.

17def squared_gap(lam) → float

Defines the J1 objective as a Python function of one scalar lam. Form: J1(lam) = (lam·g_rul − (1−lam)·g_health)². Minimising drives the squared gap to zero, hence c_rul = c_health, hence lam = lambda_closed.

EXECUTION STATE
⬇ input: lam (float) = The proposed lambda_rul, in [0, 1]. lambda_health is implicitly 1 - lam (the simplex constraint is baked into the parametrisation, so we never need to track it as a separate variable).
→ lam purpose = This is the optimisation variable. SciPy will repeatedly call squared_gap(lam) at different lam values, looking for the one that minimises the return.
: float (annotation) = Type hint. Tells static checkers that lam is a Python float. Has no runtime effect — Python is duck-typed — but documents intent.
→ float = Return-type annotation. Documents that the function returns one Python float.
g_rul, g_health (globals) = Read from module scope rather than passed as arguments. They are 5.0 and 0.01 respectively. Python looks them up at call time, not definition time, so changing them at the module level would change subsequent calls.
⬆ returns = Non-negative float. Equals 0 exactly when lam = lambda_closed; strictly > 0 everywhere else. Quadratic in lam, so the minimiser converges fast.
18docstring: J1(lam) = (lam * g_rul - (1 - lam) * g_health) ** 2

Triple-quoted docstring. Records the algebraic form of J1 directly inside the function so a reader hovering on squared_gap in an editor sees the formula. help(squared_gap) and IDE tooltips render this string.

19return (lam * g_rul - (1 - lam) * g_health) ** 2

Single-expression return. Computes c_rul − c_health, then squares the difference. Python evaluates left-to-right respecting standard operator precedence: multiplications first, then subtraction, then **2.

EXECUTION STATE
lam * g_rul = The effective RUL contribution c_rul. At lam=0.001996: 0.001996 * 5.0 = 0.00998.
(1 - lam) = Implicit lambda_health. At lam=0.001996: 1 - 0.001996 = 0.998004.
(1 - lam) * g_health = The effective health contribution c_health. At lam=0.001996: 0.998004 * 0.01 = 0.00998004.
(c_rul - c_health) = At the optimum: 0.00998 - 0.00998 ≈ 0. Anywhere else, non-zero.
** 2 = Python power operator. Squares the gap so the function is non-negative AND smooth (quadratic). Because the objective is an exact parabola in lam, the parabolic-interpolation step of Brent's method fits it well and converges fast.
→ walk through at lam=0.5 (mid-point) = c_rul = 0.5 * 5.0 = 2.5 c_health = 0.5 * 0.01 = 0.005 gap = 2.5 - 0.005 = 2.495 gap² = 6.225 — far from zero, so 0.5 is far from the optimum.
→ walk through at lam=0.001996 (closed form) = c_rul = 0.001996 * 5.0 = 0.00998 c_health = 0.998004 * 0.01 = 0.00998004 gap = 0.00998 - 0.00998004 ≈ 0 gap² ≈ 0 ✓
⬆ return = Non-negative float. 0 at the optimum; ~6.225 at lam=0.5.
22# Method 2: maximise the minimum contribution (max-min fairness)

Section header for objective 2. Max-min fairness, also known as the Rawlsian criterion, asks: 'pick the allocation that lifts the worst-off task as high as possible.' Famous in networking (max-min fair bandwidth sharing), economics (egalitarian welfare), and political philosophy.

23def neg_min_contribution(lam) → float

Defines the J2 objective. Returns NEGATIVE of min(c_rul, c_health). Why negative? SciPy's minimize_scalar only minimises. To maximise the worst-off task, we minimise its negative, a standard sign-flip trick that turns any maximisation into a minimisation.

EXECUTION STATE
⬇ input: lam (float) = The proposed lambda_rul, in [0, 1].
→ lam purpose = Same role as in J1: the optimisation variable.
⬆ returns = Non-positive float. Equals -|c_min|. The most-negative value (= largest c_min) sits at the kink where c_rul = c_health, the same point as J1's minimum.
→ why is the optimum the same? = On [0, 1], c_rul = lam·g_rul (increasing) and c_health = (1-lam)·g_health (decreasing). min(c_rul, c_health) is the point-wise lower envelope: it rises from the left, peaks where the two lines cross (c_rul = c_health), then falls to the right. The peak IS the equal-contribution point.
24docstring: J2(lam) = -min(c_rul, c_health). minimise this == maximise the min.

Records the sign-flip trick in the function's own docstring so the negation is documented at the point where it's introduced.

25return -min(lam * g_rul, (1 - lam) * g_health)

Compute the two contributions inline, take the smaller with Python's built-in min(), then negate.

EXECUTION STATE
📚 min(a, b) = Python builtin. For two arguments: returns the smaller via plain comparison. For floats: pure scalar comparison, no broadcasting, no NumPy semantics.
⬇ arg 1: lam * g_rul = c_rul. The effective contribution of the RUL task at the current lam.
⬇ arg 2: (1 - lam) * g_health = c_health. The effective contribution of the health task at the current lam.
→ why min, not mean? = Max-min fairness specifically targets the WORST-OFF party. Using mean would just maximise the average and could starve a small task entirely. min is the egalitarian / Rawlsian choice: protect the most disadvantaged.
- (unary negate) = Python unary minus. Flips the sign so SciPy's minimiser maximises the min.
→ walk through at lam=0.001996 = c_rul = 0.00998 c_health = 0.00998 min(0.00998, 0.00998) = 0.00998 -0.00998 → returned. This is as negative as the function gets — i.e. the maximum of min().
→ walk through at lam=0.5 = c_rul = 2.5 c_health = 0.005 min(2.5, 0.005) = 0.005 (the smaller one, c_health) -0.005 → returned. Less negative than at the optimum, so further from the minimum.
⬆ return = Negative float. Most-negative value is -0.00998 at lam = lambda_closed.
28# Method 3: minimise variance of contributions

Section header for objective 3. Statistical-fairness framing: zero variance ⇔ equal contributions. For K=2 this is algebraically proportional to J1, but writing it as variance highlights the fairness intuition.

29def variance_contribution(lam) → float

Defines the J3 objective. Build the (2,) vector of contributions and return its population variance. Var = 0 IFF all entries are equal IFF c_rul = c_health.

EXECUTION STATE
⬇ input: lam (float) = The proposed lambda_rul, in [0, 1].
⬆ returns = Non-negative float. = 0 at the optimum.
→ why bother defining variance? = It's the same answer as J1 up to a constant factor (we prove this in line 32). Including it shows that statistical-fairness intuitions and equal-contribution intuitions are the SAME mathematical object for K=2.
30docstring: J3(lam) = Var([c_rul, c_health])

Records the J3 formula in the function docstring.

31c = np.array([lam * g_rul, (1 - lam) * g_health])

Build a length-2 NumPy array containing the two effective contributions, so we can call .var() on it. We could compute variance by hand for K=2, but using ndarray makes the code generalise trivially to K > 2.

EXECUTION STATE
📚 np.array(object, dtype=None) = Build an ndarray from any nested sequence. dtype is inferred from the input — here both elements are floats so dtype = float64. Returns a contiguous C array, not a Python list.
⬇ arg: [lam * g_rul, (1 - lam) * g_health] = Python list of two floats. Constructed inline. NumPy copies these into a new (2,) ndarray.
→ walk through at lam=0.5 = lam * g_rul = 0.5 * 5.0 = 2.5 (1 - lam) * g_health = 0.5 * 0.01 = 0.005 List: [2.5, 0.005] ndarray: array([2.5 , 0.005])
→ walk through at lam=0.001996 = lam * g_rul = 0.001996 * 5.0 = 0.00998 (1 - lam) * g_health = 0.998004 * 0.01 = 0.00998004 List: [0.00998, 0.00998004] ndarray: array([0.00998 , 0.00998004])
c = (2,) float64 ndarray. Element 0 is c_rul; element 1 is c_health.
32return c.var()

Compute the population variance of the (2,) contribution vector. For K=2 this collapses algebraically to (1/4)·(c_rul − c_health)² — exactly J1 / 4. Hence J3 and J1 share the SAME argmin, even though they come from different statistical motivations.

EXECUTION STATE
📚 ndarray.var(axis=None, dtype=None, ddof=0) = ndarray reduction. Returns the variance: mean of squared deviations from the mean. Default ddof=0 → POPULATION variance ((1/N)·Σ(x − x̄)²); ddof=1 would give the unbiased SAMPLE variance ((1/(N−1))·Σ(x − x̄)²).
→ why ddof=0 is fine here = We're not estimating a population from a sample: c is the entire 'population' of two contributions. ddof=0 is the right choice and the NumPy default.
→ algebraic simplification for N=2 = Let a = c_rul, b = c_health. Mean: (a + b) / 2 Deviations: (a - b) / 2 and -(a - b) / 2 Squared: ((a - b) / 2)² each Mean of squared deviations: ((a - b) / 2)² = (a - b)² / 4 So Var(c) = J1 / 4.
→ walk through at lam=0.001996 = c = [0.00998, 0.00998004] mean = 0.00998002 deviations = [-2e-8, +2e-8] squared = [4e-16, 4e-16] Var = 4e-16 ≈ 0 ✓
⬆ return = Non-negative float. Equals (J1/4); both have the same argmin = lambda_closed.
35lam1 = minimize_scalar(squared_gap, bounds=(0, 1), method='bounded').x

Hand the J1 objective to SciPy's 1-D bounded minimiser. Brent's algorithm fits a parabola to three sample points, jumps to the parabola's vertex, and falls back to a golden-section step if the parabolic fit is unacceptable. Robust, derivative-free, typically converges in ~30 evaluations.

EXECUTION STATE
📚 minimize_scalar(fun, bracket=None, bounds=None, method='brent', tol=None, options=None) = SciPy's 1-D minimiser. Returns OptimizeResult, a namespace-style object with .x (the argmin), .fun (the minimum value), .nfev (number of function evaluations), .success (convergence flag), and more.
⬇ arg 1: squared_gap (callable) = The objective function. Must be callable as fun(x) → float for a single Python float x. SciPy will call this many times during the search.
⬇ arg 2: bounds=(0, 1) = Closed search interval (a, b). Required when method='bounded'. Encodes our constraint lam in [0, 1]; outside this range the contributions have no physical meaning (negative weights).
→ why a 2-tuple? = SciPy expects (lo, hi). The minimiser will never evaluate fun outside this interval, so we don't need defensive 'if lam < 0' guards inside squared_gap.
⬇ arg 3: method='bounded' = Selects the bounded variant of Brent's method. Other choices: 'brent' (uses bracket, not bounds, so it can wander); 'golden' (pure golden-section, slower than Brent on smooth functions). 'bounded' is the right choice for a constrained interval.
.x = OptimizeResult attribute access. .x is the argmin (a Python float for scalar problems; a NumPy array for vector problems). We discard the rest of the result object.
lam1 = 0.0019960080. Matches lambda_closed to 10 decimal places. SciPy got there in roughly 30 evaluations of squared_gap.
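The script keeps only .x, but the rest of the OptimizeResult is useful when judging how far to trust a match with the closed form. The sketch below is an illustrative variation, not part of the original script: it reruns the squared-gap objective once with default settings and once with a tighter xatol (the bounded method's absolute tolerance on x), then prints the evaluation count and the distance to the closed form. The value 1e-12 is an arbitrary choice.

```python
from scipy.optimize import minimize_scalar

g_rul, g_health = 5.0, 0.01
lambda_closed = g_health / (g_rul + g_health)


def squared_gap(lam: float) -> float:
    """J1(lam): squared difference between the two weighted contributions."""
    return (lam * g_rul - (1 - lam) * g_health) ** 2


# Default tolerance, then a tighter absolute tolerance on x.
res_default = minimize_scalar(squared_gap, bounds=(0, 1), method="bounded")
res_tight = minimize_scalar(squared_gap, bounds=(0, 1), method="bounded",
                            options={"xatol": 1e-12})

for name, res in [("default", res_default), ("tight", res_tight)]:
    print(f"{name:8s} x={res.x:.12f}  nfev={res.nfev}  success={res.success}  "
          f"|x - closed|={abs(res.x - lambda_closed):.2e}")
```

Comparing the last column across the two runs shows how much of the agreement is due to the solver's tolerance rather than the objective itself.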
36lam2 = minimize_scalar(neg_min_contribution, bounds=(0, 1), method='bounded').x

Same machinery applied to J2 (max-min fairness via the negation trick). Same bounds, same method, different objective.

EXECUTION STATE
⬇ arg 1: neg_min_contribution = The J2 objective. Brent's method handles the kink at c_rul = c_health (where the inner min switches sides): it's smooth ALMOST everywhere, and the bounded variant doesn't require differentiability.
→ derivative-free is critical here = J2 has a non-differentiable kink at lambda_closed (the min() flips). Gradient methods would struggle. Brent's parabolic-and-golden-section hybrid handles non-smooth functions gracefully.
lam2 = 0.0019960080. Same answer to 10 decimals despite the totally different objective shape.
37lam3 = minimize_scalar(variance_contribution, bounds=(0, 1), method='bounded').x

Same again for J3 (minimum variance of contributions).

EXECUTION STATE
⬇ arg 1: variance_contribution = The J3 objective. Builds a length-2 ndarray and returns its variance. Smooth and quadratic in lam (because variance is quadratic in c, and c is linear in lam).
lam3 = 0.0019960080. Same answer, confirming numerically that Var(c) = J1 / 4 has the same argmin as J1.
40print(f"closed form lam* = {lambda_closed:.10f}")

f-string formatted print. The {lambda_closed:.10f} substitution renders the float to 10 decimal places of fixed-point notation.

EXECUTION STATE
f"..." = Python f-string (formatted string literal). Expressions inside {} are evaluated at runtime and substituted into the string. Adopted in Python 3.6+.
:.10f = Format spec inside the {}: print as fixed-point with 10 digits after the decimal. .10f → 0.0019960080. .6e would give scientific notation 1.996008e-03.
Output = closed form lam* = 0.0019960080
41print(f"min squared gap lam* = {lam1:.10f}")

Print the J1 numerical argmin in the same format so visual comparison is exact column-by-column.

EXECUTION STATE
Output = min squared gap lam* = 0.0019960080
→ matches closed form? = Yes — to all 10 decimal places. Squared-gap minimisation reproduces the closed form.
42print(f"max-min fairness lam* = {lam2:.10f}")

Print the J2 (max-min) numerical argmin.

EXECUTION STATE
Output = max-min fairness lam* = 0.0019960080
→ matches closed form? = Yes — to all 10 decimal places. Max-min fairness reproduces the closed form.
43print(f"min variance of c lam* = {lam3:.10f}")

Print the J3 (variance) numerical argmin.

EXECUTION STATE
Output = min variance of c lam* = 0.0019960080
→ reading the four-line block = All four numbers are byte-identical to 10 decimals. The closed form is simultaneously the unique optimum of THREE different objectives — empirical confirmation of the algebraic claim.
46# ---------- Sensitivity of the closed form ----------

Section header. Now we compute the partial derivatives ∂lam*/∂g_i analytically. These tell us how the closed form responds to perturbations in either gradient norm — critical for understanding why GABA needs an EMA stabiliser on g_health but not on g_rul.

47dlam_drul = -g_health / S ** 2

Quotient rule on lam* = g_health / (g_rul + g_health). Treat g_health as constant, differentiate w.r.t. g_rul: d/dg_rul (g_health / (g_rul + g_health)) = -g_health / (g_rul + g_health)² = -g_health / S².

EXECUTION STATE
** (power) = Python power operator. S ** 2 = S * S = 25.1001.
→ walk through arithmetic = S ** 2 = 5.01 * 5.01 = 25.1001 Numerator: -g_health = -0.01 Quotient: -0.01 / 25.1001 = -3.984048e-04
dlam_drul = ≈ -3.984e-04. Tiny negative — increasing g_rul slightly DECREASES lambda_rul (the formula already down-weights the dominant gradient, and a bigger dominant gradient pushes the down-weight even smaller).
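→ engineering meaning = A 1.0-unit jump in g_rul (a 20% relative change) shifts lam* by only about 4e-4 in absolute terms (to first order). That is roughly 20% of lam* itself, so the relative sensitivity is close to 1; the absolute shift is small because lam* is small. In absolute terms, lam* is robust to RUL-side noise.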
48dlam_dhealth = g_rul / S ** 2

Same quotient rule, now differentiating w.r.t. g_health: d/dg_health (g_health / (g_rul + g_health)) = (1·(g_rul + g_health) - g_health·1) / (g_rul + g_health)² = g_rul / S². The asymmetry is structural: each partial puts the OTHER gradient in the numerator.

EXECUTION STATE
→ walk through arithmetic = S ** 2 = 25.1001 Numerator: g_rul = 5.0 Quotient: 5.0 / 25.1001 = 0.1992024
dlam_dhealth = ≈ 0.1992. About 500 times LARGER than dlam_drul (between two and three orders of magnitude). The smaller-gradient channel dominates the sensitivity.
→ why the asymmetry? = lam* is a fraction with g_health in the numerator. When g_health is small, even tiny absolute changes are large RELATIVE changes — and the fraction shifts dramatically.
→ ratio of sensitivities = |dlam_dhealth| / |dlam_drul| = 0.1992 / 3.984e-4 = 500. Same as g_rul / g_health. The sensitivity ratio EQUALS the gradient-norm ratio — fundamental algebraic fact.
→ practical engineering implication = This is exactly why GABA uses EMA(beta=0.99) on the small-gradient side but not on the large-gradient side: per-batch noise on g_health would otherwise produce visible lam* swings. Damp the channel that drives the sensitivity.
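To illustrate the damping effect described above, here is a minimal sketch of exponential-moving-average smoothing applied to g_health before it enters the closed form. Only beta = 0.99 comes from the text. The noise model (about 30% multiplicative, log-normal), the random seed, the sequence length, and the choice to initialise the EMA at the first measurement are all assumptions made for illustration, not details of GABA's actual implementation.

```python
import numpy as np

rng = np.random.default_rng(0)   # fixed seed so the sketch is repeatable
g_rul, beta = 5.0, 0.99          # beta = 0.99 is the value quoted in the text

# Hypothetical per-batch measurements: 0.01 with ~30% multiplicative noise.
g_health_raw = 0.01 * np.exp(0.3 * rng.standard_normal(2000))

# Exponential moving average, initialised at the first measurement (assumption).
g_health_ema = np.empty_like(g_health_raw)
state = g_health_raw[0]
for t, g in enumerate(g_health_raw):
    state = beta * state + (1.0 - beta) * g
    g_health_ema[t] = state

lam_raw = g_health_raw / (g_rul + g_health_raw)
lam_ema = g_health_ema / (g_rul + g_health_ema)

# Skip the first 500 steps so the EMA has forgotten its initial value.
print(f"std of lam* (raw g_health): {lam_raw[500:].std():.2e}")
print(f"std of lam* (EMA g_health): {lam_ema[500:].std():.2e}")
```

Under these assumptions the smoothed weight fluctuates far less than the raw one, which is the behaviour the sensitivity analysis predicts for the small-gradient channel.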
49print(f"\nd lam* / d g_rul = {dlam_drul:.6e}")

Print the RUL-side partial derivative in scientific notation. The leading \n inside the f-string emits a blank line first, separating the sensitivity block from the four-line argmin comparison above.

EXECUTION STATE
\n = Newline escape sequence. Forces a blank line before this row of output.
:.6e = Format spec: scientific notation with 6 digits after the decimal in the mantissa. .6e → -3.984048e-04. Useful when the magnitude varies wildly (here -4e-4 vs +0.2 — fixed-point would lose precision on the small one).
Output = (blank line) d lam* / d g_rul = -3.984048e-04
50print(f"d lam* / d g_health = {dlam_dhealth:.6e}")

Final print. The two sensitivities side-by-side make the asymmetry impossible to miss: one is 500x larger than the other.

EXECUTION STATE
Final output (whole script) =
closed form          lam* = 0.0019960080
min squared gap      lam* = 0.0019960080
max-min fairness     lam* = 0.0019960080
min variance of c    lam* = 0.0019960080

d lam* / d g_rul    = -3.984048e-04
d lam* / d g_health = 1.992024e-01
→ final reading = Top block: three numerical methods all reproduce the closed form to 10 decimals, so the formula is structural. Bottom block: sensitivity is 500x asymmetric, so the small-gradient channel needs the EMA and the large-gradient channel doesn't.
```python
"""Three optimisation objectives, one closed form."""

import numpy as np
from scipy.optimize import minimize_scalar


# ---------- Realistic FD002 numbers from section 12.3 ----------
g_rul, g_health = 5.0, 0.01
S = g_rul + g_health


# Method 0: closed form (analytic)
lambda_closed = g_health / S


# Method 1: minimise the squared gap between contributions
def squared_gap(lam: float) -> float:
    """J1(lam) = (lam * g_rul - (1 - lam) * g_health) ** 2"""
    return (lam * g_rul - (1 - lam) * g_health) ** 2


# Method 2: maximise the minimum contribution (max-min fairness)
def neg_min_contribution(lam: float) -> float:
    """J2(lam) = -min(c_rul, c_health). minimise this == maximise the min."""
    return -min(lam * g_rul, (1 - lam) * g_health)


# Method 3: minimise variance of contributions
def variance_contribution(lam: float) -> float:
    """J3(lam) = Var([c_rul, c_health])"""
    c = np.array([lam * g_rul, (1 - lam) * g_health])
    return c.var()


lam1 = minimize_scalar(squared_gap,           bounds=(0, 1), method="bounded").x
lam2 = minimize_scalar(neg_min_contribution,  bounds=(0, 1), method="bounded").x
lam3 = minimize_scalar(variance_contribution, bounds=(0, 1), method="bounded").x


print(f"closed form          lam* = {lambda_closed:.10f}")
print(f"min squared gap      lam* = {lam1:.10f}")
print(f"max-min fairness     lam* = {lam2:.10f}")
print(f"min variance of c    lam* = {lam3:.10f}")


# ---------- Sensitivity of the closed form ----------
dlam_drul    = -g_health / S ** 2
dlam_dhealth =  g_rul    / S ** 2
print(f"\nd lam* / d g_rul    = {dlam_drul:.6e}")
print(f"d lam* / d g_health = {dlam_dhealth:.6e}")
```
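The two analytic partial derivatives can be sanity-checked with central finite differences. This is a small independent sketch, not part of the script above; the helper name lam_star and the step size h are illustrative choices.

```python
import math


def lam_star(g_rul: float, g_health: float) -> float:
    """K=2 closed form for lambda_rul (hypothetical helper name)."""
    return g_health / (g_rul + g_health)


g_rul, g_health = 5.0, 0.01
S = g_rul + g_health

# Analytic partials, as derived via the quotient rule.
analytic_rul = -g_health / S ** 2
analytic_health = g_rul / S ** 2

# Central finite differences with an illustrative step size.
h = 1e-6
fd_rul = (lam_star(g_rul + h, g_health) - lam_star(g_rul - h, g_health)) / (2 * h)
fd_health = (lam_star(g_rul, g_health + h) - lam_star(g_rul, g_health - h)) / (2 * h)

print(f"d lam*/d g_rul    analytic={analytic_rul:.6e}  finite-diff={fd_rul:.6e}")
print(f"d lam*/d g_health analytic={analytic_health:.6e}  finite-diff={fd_health:.6e}")
print(f"sensitivity ratio = {analytic_health / abs(analytic_rul):.1f}")
```

The printed ratio equals g_rul / g_health, which is the 500x asymmetry between the two channels.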

PyTorch: Gradient Descent on a Learnable λ

Replace the off-the-shelf solver with autograd. Parametrise $\lambda = \sigma(a)$ with a single learnable scalar $a$ and minimise the squared-gap loss with Adam. Convergence to the closed form is the operational check.
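Under the sigmoid parametrisation the target also has a closed form in the logit. Setting $\sigma(a^*) = g_{\text{health}} / (g_{\text{rul}} + g_{\text{health}})$ and solving gives $a^* = \ln(g_{\text{health}} / g_{\text{rul}})$, assuming both norms are strictly positive. The plain-Python sketch below is independent of the PyTorch script and only verifies that identity, which gives a reference value that any learned $a$ should approach.

```python
import math

g_rul, g_health = 5.0, 0.01  # both must be strictly positive for the log


def sigmoid(a: float) -> float:
    """Logistic function mapping a real logit to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-a))


a_star = math.log(g_health / g_rul)          # closed-form logit
lam_from_logit = sigmoid(a_star)             # should reproduce lambda*
lam_closed = g_health / (g_rul + g_health)   # K=2 closed form

print(f"a*              = {a_star:.6f}")
print(f"sigmoid(a*)     = {lam_from_logit:.10f}")
print(f"closed-form lam = {lam_closed:.10f}")
```

Because $a^*$ is a logit near $-6.2$ while a typical initialisation is 0, the optimiser has a fair distance to travel in logit space, which is one reason the learned route takes many steps.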

Gradient descent recovers the closed form
🐍learn_lambda.py
1Module docstring

States the demo: gradient descent on a single learnable scalar, with a sigmoid-parametrised lambda and the squared-gap loss, must converge to the analytic closed form. If it doesn't, either the algebra is wrong or the optimiser is mis-tuned. The demo also serves as a cost comparison: 2001 Adam steps to approximate what one division computes exactly.

EXECUTION STATE
→ why this demo exists = It's the operational answer to 'why not just learn lambda?' The demo shows that gradient descent works in principle (it does converge) but is wasteful in practice (2001 steps vs. 1 division for the same answer).
3import torch

Core PyTorch. Provides Tensor (the GPU-capable autograd-tracked array type) and the namespaces torch.optim (optimisers) and torch.nn.functional (functional layers). For this demo we touch torch.tensor, torch.zeros, torch.sigmoid, and torch.optim.Adam.

EXECUTION STATE
📚 torch = PyTorch's top-level package. Tensor library with reverse-mode autograd. Tensors are like NumPy arrays but with three extras: GPU support, autograd tracking, and a dynamic computation graph.
→ why PyTorch instead of NumPy here = We need autograd. NumPy has no automatic differentiation, so writing 'minimise loss by gradient descent' would mean hand-deriving and hand-coding the gradient. PyTorch builds the graph as we go and computes ∂loss/∂a for free via .backward().
6# ---------- Realistic FD002 numbers from section 12.3 ----------

Same paper-anchored numbers as the NumPy demo, so the convergence target is directly comparable to the closed-form value computed there.

7g_rul = torch.tensor(5.0)

Build a 0-dim (scalar) float tensor holding the RUL gradient norm. We treat it as a fixed measured quantity — no requires_grad — so autograd ignores it during backward().

EXECUTION STATE
📚 torch.tensor(data, *, dtype=None, device=None, requires_grad=False) = Constructor that copies data into a new Tensor. With dtype=None the dtype is inferred (float32 for Python floats, int64 for ints); with device=None the tensor lands on the current default device (the CPU unless changed). Always copies (use torch.as_tensor to share memory). The 0-dim shape () means a single scalar (different from shape (1,) which is a 1-element vector).
⬇ arg: 5.0 = Python float. Becomes the single value of the new tensor.
g_rul = 0-dim float32 tensor = 5.0. Shape: (). dtype: torch.float32. requires_grad: False (default).
→ why no requires_grad? = g_rul is data, not a learnable parameter. We measured it; we don't want autograd to compute ∂loss/∂g_rul or include it in any optimiser update.
8g_health = torch.tensor(0.01)

Same construction for the small-gradient side. Same dtype, same shape, same requires_grad=False.

EXECUTION STATE
⬇ arg: 0.01 = Python float.
g_health = 0-dim float32 tensor = 0.01. Note: float32 has only ~7 decimal digits of precision, so 0.01 is actually stored as ≈ 0.009999999776482582, which is close enough that it doesn't affect the demo.
11# Closed form target.

Section header for the analytic reference value we'll compare against at the end of the script.

12lambda_target = (g_health / (g_rul + g_health)).item()

Compute the closed-form target inside PyTorch (so we use the same float32 arithmetic as the rest of the demo) then unwrap to a Python float via .item() for printing/comparison.

EXECUTION STATE
g_health / (g_rul + g_health) = All operands are 0-dim tensors. Result: 0-dim tensor = 0.01 / 5.01 = 0.001996. Neither operand has requires_grad=True, so no autograd graph is recorded for this expression.
📚 Tensor.item() = Method on a tensor with exactly one element (0-dim, or any shape holding a single entry). Returns the underlying value as a Python scalar (int or float). Detaches from autograd. Errors if called on a tensor with more than one element.
→ why .item() vs leaving it a tensor? = Python floats print and format more cleanly than tensors, and we want to do the final |final - target| comparison with stdlib abs(). Item-extraction also breaks the autograd link, which is fine because lambda_target is a constant.
lambda_target = 0.001996 (Python float). The number gradient descent must reach.
15# A learnable scalar 'a'. We parametrise lambda = sigmoid(a) so it is

First half of a two-line comment block explaining WHY we parametrise lambda through a sigmoid rather than constraining it directly.

16# always in [0, 1] without needing a constraint.

Second half. The trick: optimise an unconstrained scalar a ∈ ℝ; let lambda = σ(a) ∈ (0, 1). Adam can move a freely without violating the simplex constraint. Without this re-parametrisation we'd need projected gradient or a constrained solver, both heavier and less robust for a 1-D problem.

17a = torch.zeros(1, requires_grad=True)

Create the single learnable parameter. Shape (1,) — a length-1 vector, not a 0-dim scalar — so optimisers iterate over it cleanly. Initialised at 0 so sigmoid(0) = 0.5 (the neutral midpoint of [0, 1], maximally uninformative starting point).

EXECUTION STATE
📚 torch.zeros(*size, dtype=None, device='cpu', requires_grad=False) = Build a tensor of zeros. *size is variadic — torch.zeros(1) → shape (1,); torch.zeros(2, 3) → shape (2, 3). Default dtype = float32.
⬇ arg 1: 1 (size) = Shape (1,). One scalar parameter wrapped in a length-1 vector. Using torch.zeros(1) instead of torch.tensor(0.0) so the parameter has a definite shape — Adam iterates over each element of each parameter tensor.
⬇ arg 2: requires_grad=True = Marks this tensor as a LEAF parameter that autograd should track. After loss.backward(), a.grad will hold ∂loss/∂a as a tensor of the same shape.
a = Shape (1,) float32 tensor with requires_grad=True. Initial value: tensor([0.]). The single optimiser-tracked parameter.
→ why sigmoid(a) parametrisation? = Sigmoid σ(x) = 1/(1+e^(-x)) maps ℝ → (0, 1). Equivalent to a soft-constrained simplex without Lagrange multipliers or projection steps. σ(0) = 0.5 = perfect neutral start.
→ cost of sigmoid parametrisation = Sigmoid saturates as |a| → ∞: σ'(a) → 0. To reach lam = 0.002, we need a ≈ -6.2, where σ'(-6.2) ≈ 0.002. The gradient on a is multiplied by this near-zero number, so progress slows dramatically near the target. This is the source of the 'sigmoid plateau' that costs us 2001 iterations.
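The saturation cost described above can be checked in plain Python, without PyTorch. This is a minimal sketch: the helper `sigmoid` and the variable names are illustrative and not part of learn_lambda.py. It computes where the target sits in a-space and how much smaller the sigmoid slope is there than at the starting point.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, mapping the real line onto (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


g_rul, g_health = 5.0, 0.01

# Closed-form target and its pre-image under the sigmoid (the logit).
lam_target = g_health / (g_rul + g_health)
a_target = math.log(lam_target / (1.0 - lam_target))

# Sigmoid slope sigma'(a) = sigma(a) * (1 - sigma(a)) at the start and at the target.
slope_at_start = sigmoid(0.0) * (1.0 - sigmoid(0.0))
slope_at_target = sigmoid(a_target) * (1.0 - sigmoid(a_target))

print(f"a needed to hit lam* : {a_target:.4f}")
print(f"slope at a = 0       : {slope_at_start:.4f}")
print(f"slope at the target  : {slope_at_target:.6f}")
print(f"slope ratio          : {slope_at_start / slope_at_target:.1f}x")
```

The ratio printed on the last line is the factor by which the chain rule shrinks the gradient on a once lambda is near its target.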
18optimiser = torch.optim.Adam([a], lr=0.5)

Construct an Adam optimiser bound to our single parameter. Adam (Kingma & Ba 2014) maintains exponentially-weighted running estimates of the first moment (mean of gradients, m) and second moment (uncentred variance, v); each step uses m / (sqrt(v) + ε) so each parameter gets a per-coordinate adaptive learning rate.

EXECUTION STATE
📚 torch.optim.Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0) = Adam optimiser. Maintains m (first moment, β₁=0.9) and v (second moment, β₂=0.999). Update: a ← a - lr · m̂ / (√v̂ + ε), where m̂ and v̂ are bias-corrected. Reliable default for most deep-learning training.
⬇ arg 1: [a] (params) = Iterable of parameter tensors to optimise. Must be Python list/tuple of tensors with requires_grad=True. We pass [a] — a list containing the one tensor.
→ why a list? = torch.optim API expects an iterable of parameters (or parameter groups for per-group hyperparams). Even with a single parameter we wrap it in a list. For a real model: optimiser = Adam(model.parameters(), lr=1e-3).
⬇ arg 2: lr=0.5 (learning rate) = Step-size scale. Adam's update is lr · m̂ / (√v̂ + ε), so a single step moves a by roughly lr at most. Aggressive: typical neural-net training uses 1e-3 to 1e-4. This 1-D objective has to travel from a = 0 to a ≈ -6.2, so a large LR is both affordable and useful.
→ why such a high lr? = The target sits at a ≈ -6.2, deep in the sigmoid plateau. With lr=1e-3, Adam would move a by about 1e-3 per step at best, so covering that distance would take several thousand steps before the plateau even starts to bite. lr=0.5 crosses most of the distance within the first 100 steps (a = -5.79 at step 100) and leaves the remaining iterations for the slow approach along the plateau.
optimiser = torch.optim.Adam instance. Holds: param_groups[0]['params'] = [a]; state[a] = {step: 0, exp_avg: zeros_like(a), exp_avg_sq: zeros_like(a)}. The state grows as we step().
21print(f"{'step':>4} | {'a':>10} {'lambda':>10} {'gap_sq':>12}")

f-string with width-formatted column headers. The :>4 means right-align in a 4-character field; :>10 → 10 chars; :>12 → 12 chars. Quoting the literal labels inside f-strings with single quotes nested inside double quotes lets us format strings the same way as numbers later.

EXECUTION STATE
:>N (format spec) = Right-align in field of width N. Useful for tabular output where columns must line up.
Output = step | a lambda gap_sq
22print("-" * 50)

Python string-multiplication trick: '-' * 50 produces a string of 50 dashes. Cheap divider line under the header.

EXECUTION STATE
* on string = String repetition. 'ab' * 3 → 'ababab'. Used here to draw a horizontal rule the right width.
Output = --------------------------------------------------
23for step in range(2001):

The training loop. range(2001) yields 0, 1, 2, ..., 2000 (note: range is exclusive on the upper end). 2001 steps because the sigmoid plateau near the target requires many small updates — far more than a well-scaled neural-net loss would.

EXECUTION STATE
📚 range(stop) = Python builtin. Returns a lazy sequence of the integers 0, 1, ..., stop-1; it doesn't materialise the full list. range(2001) iterates 2001 times.
→ why 2001 and not 2000? = Cosmetic: gives us the round number 2000 as a checkpoint without missing the final iteration. range(2001) covers steps 0 through 2000 inclusive.
LOOP TRACE · 5 iterations
step = 0 (initial state, before update)
a = tensor([0.])
lam = sigmoid(0) = 0.5
c_rul = 0.5 * 5.0 = 2.5
c_health = 0.5 * 0.01 = 0.005
loss = (2.5 - 0.005)² = 2.495² = 6.225
→ status = Far from target — loss is huge. Adam will start pushing a strongly negative.
step = 100
a = -5.7877
lam = sigmoid(-5.7877) = 0.003056, still about 1.5x the target (0.001996)
loss = 2.8205e-05 (down from 6.225 — five orders of magnitude in 100 steps)
→ status = Most of the loss reduction happens in the first 100 steps. The remaining 1900 steps wrestle with the sigmoid plateau.
step = 500
a = -5.8518 (only -0.06 change in 400 steps)
lam = 0.002867
loss = 1.9041e-05
→ status = Plateau in full effect. σ'(-5.85) ≈ 0.0029 shrinks the gradient on a, and Adam's second-moment estimate v̂ still remembers the large early gradients (β₂ = 0.999), so the normalised step lr · m̂ / √v̂ is tiny.
step = 1000
a = -5.9520
lam = 0.002594
loss = 8.9853e-06
step = 2000 (last printed)
a = -6.1114
lam = 0.002213 — within 2.2e-04 of the closed form (0.001996)
loss = 1.1804e-06
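→ status = Still converging, but slowly: another 2000 steps would shave the gap to ~1e-04. The bottleneck is the saturated sigmoid, which keeps the gradient on a tiny, not float32 precision (float32 resolves values near 0.002 to about 1e-10).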
24lam = torch.sigmoid(a)

Forward pass step 1: map the unconstrained learnable scalar a into the (0, 1) interval via the sigmoid function. σ(x) = 1 / (1 + e^(-x)). Always strictly between 0 and 1; differentiable everywhere; saturates as |x| → ∞.

EXECUTION STATE
📚 torch.sigmoid(input) = Element-wise sigmoid. Equivalent to 1 / (1 + torch.exp(-input)). Implemented as a fused, numerically-stable kernel that avoids overflow for large positive or negative inputs.
⬇ arg: a = 0-dim or shape-(1,) tensor. Our learnable parameter.
→ sigmoid derivative = σ'(x) = σ(x)·(1 - σ(x)). Maximum at x=0 (value 0.25); falls off symmetrically. At x=-6.2: σ'(-6.2) ≈ 0.002·0.998 ≈ 0.002. This near-zero derivative is the sigmoid plateau.
lam = Shape-(1,) tensor in (0, 1). Differentiable in a — autograd records the sigmoid op so gradients flow through it on backward().
25c_rul = lam * g_rul

Forward pass step 2: scale the RUL gradient by the current lambda. This is the effective contribution c_rul that GABA aims to balance.

EXECUTION STATE
* on tensors = Element-wise multiplication. Broadcasts (1,) and () shapes to (1,). Tracked by autograd because lam has requires_grad propagated through it.
lam (operand) = Shape (1,), value in (0, 1).
g_rul (operand) = Shape () = (5.0). No gradient tracked.
c_rul = Shape (1,) tensor. The effective RUL contribution. Differentiable in a via lam.
26c_health = (1 - lam) * g_health

Forward pass step 3: implicit lambda_health = 1 - lam (simplex constraint baked into parametrisation). Multiply by g_health to get the effective health contribution.

EXECUTION STATE
1 - lam = Scalar minus tensor. PyTorch broadcasts the Python int 1 to a tensor, then subtracts. Result has the same shape as lam.
(1 - lam) * g_health = Element-wise multiply, broadcasting. Result shape (1,).
c_health = Shape (1,) tensor. The effective health contribution. Also differentiable in a via lam.
27loss = (c_rul - c_health) ** 2

Forward pass step 4: same J1 squared-gap loss as the NumPy demo. The point of the experiment is to show this loss has its minimum at exactly the closed-form lam — and that gradient descent can find it.

EXECUTION STATE
(c_rul - c_health) = Tensor subtraction. Result shape (1,).
** 2 = Element-wise square. Tensor pow operator. Same as torch.pow(x, 2) or x*x.
loss = Shape-(1,) non-negative tensor. = 0 only at the closed form. This single-element tensor is what we'll call .backward() on.
→ autograd graph at this point = Reverse-mode chain: a → sigmoid → lam → (lam·g_rul, (1-lam)·g_health) → c_rul - c_health → squared. Backward will walk this chain to compute ∂loss/∂a.
29optimiser.zero_grad()

Clear the .grad attribute of every tracked parameter before backward(). PyTorch ACCUMULATES gradients on each backward() call (so you can sum gradients across micro-batches), so without zeroing first, you'd add the new gradient to last step's and get the wrong update direction.

EXECUTION STATE
📚 Optimizer.zero_grad(set_to_none=True) = Optimiser method. For every param in self.param_groups, sets param.grad to None (the default since PyTorch 2.0) or, with set_to_none=False, to a zero tensor (the older default). Setting to None saves memory and is slightly faster.
→ why this is required = Default backward() does grad += new_grad, not grad = new_grad. Forgetting zero_grad() is the most common training-loop bug: gradients from previous steps leak in and the optimiser walks in the wrong direction.
30loss.backward()

Trigger reverse-mode autodiff. PyTorch traverses the computation graph from loss back to every leaf with requires_grad=True (here: just a) and writes the partial derivative into each leaf's .grad attribute. After this line, a.grad holds ∂loss/∂a as a tensor of the same shape as a.

EXECUTION STATE
📚 Tensor.backward(gradient=None, retain_graph=None, create_graph=False) = Tensor method. Triggers reverse-mode autodiff. For a scalar loss, no gradient arg needed (PyTorch supplies 1.0). For non-scalar loss, you must pass a gradient tensor of matching shape.
→ walk through chain rule for our loss = loss = (c_rul - c_health)²; ∂loss/∂lam = 2·(c_rul - c_health)·(g_rul + g_health) = 2·(c_rul - c_health)·S; ∂lam/∂a = σ'(a) = lam·(1-lam); ∂loss/∂a = ∂loss/∂lam · ∂lam/∂a (chain rule).
→ at step 0 = c_rul - c_health = 2.495; S = 5.01; ∂loss/∂lam = 2 · 2.495 · 5.01 = 25.0; σ'(0) = 0.25; ∂loss/∂a = 25.0 · 0.25 = 6.25 → a.grad.
→ at step 2000 = c_rul - c_health ≈ 1.09e-3; σ'(-6.11) ≈ 0.002; ∂loss/∂a ≈ 2 · 1.09e-3 · 5.01 · 0.002 ≈ 2.18e-5, which is tiny, hence the slow plateau.
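The chain-rule walk-through can be cross-checked without autograd. The sketch below is plain Python with hypothetical helper names (`loss`, `grad_analytic`, `grad_numeric`); it compares the hand-derived ∂loss/∂a against a central finite difference at the two values of a discussed above.

```python
import math

G_RUL, G_HEALTH = 5.0, 0.01


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def loss(a: float) -> float:
    """Squared gap between the two effective contributions."""
    lam = sigmoid(a)
    return (lam * G_RUL - (1.0 - lam) * G_HEALTH) ** 2


def grad_analytic(a: float) -> float:
    """d loss / d a = 2 * gap * S * sigma'(a), with S = g_rul + g_health."""
    lam = sigmoid(a)
    gap = lam * G_RUL - (1.0 - lam) * G_HEALTH
    return 2.0 * gap * (G_RUL + G_HEALTH) * lam * (1.0 - lam)


def grad_numeric(a: float, h: float = 1e-6) -> float:
    """Central finite difference, used only as an independent check."""
    return (loss(a + h) - loss(a - h)) / (2.0 * h)


for a in (0.0, -6.1114):
    print(f"a = {a:8.4f}  analytic = {grad_analytic(a):.6e}  numeric = {grad_numeric(a):.6e}")
```

At a = 0 both columns agree with the 6.25 quoted in the walk-through; at a ≈ -6.11 both are of order 1e-5.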
31optimiser.step()

Apply one Adam update: read a.grad, update internal first/second-moment running averages, compute bias-corrected estimates, write the new value back into a. After this line, a has moved by one Adam step in the descent direction.

EXECUTION STATE
📚 Optimizer.step(closure=None) = Optimiser method. For every param: m_t = β₁·m_{t-1} + (1-β₁)·g_t; v_t = β₂·v_{t-1} + (1-β₂)·g_t²; m̂_t = m_t/(1-β₁ᵗ); v̂_t = v_t/(1-β₂ᵗ); param ← param - lr·m̂_t/(√v̂_t + ε).
→ first step in detail (step=0) = a.grad = 6.25; m_1 = 0.1·6.25 = 0.625; v_1 = 0.001·39.0625 ≈ 0.0391; m̂_1 = 0.625 / (1 - 0.9) = 6.25; v̂_1 = v_1 / (1 - 0.999) = 39.0625; update = 0.5 · 6.25 / (√39.0625 + 1e-8) = 0.5 · 6.25 / 6.25 = 0.5; a ← 0 - 0.5 = -0.5.
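The first-step arithmetic can be replayed with a hand-written Adam update. This is a minimal sketch of the textbook update rule quoted above, not PyTorch's internal implementation; the function name `adam_step` is hypothetical.

```python
import math


def adam_step(param, grad, m, v, t, lr=0.5, beta1=0.9, beta2=0.999, eps=1e-8):
    """One scalar Adam update; t is the 1-based step counter."""
    m = beta1 * m + (1.0 - beta1) * grad          # first-moment running average
    v = beta2 * v + (1.0 - beta2) * grad * grad   # second-moment running average
    m_hat = m / (1.0 - beta1 ** t)                # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# First update of the demo: a = 0, gradient 6.25, empty optimiser state.
a, m, v = adam_step(param=0.0, grad=6.25, m=0.0, v=0.0, t=1)
print(f"a after step 1 = {a:.6f}, m = {m:.6f}, v = {v:.6f}")

# Same first step with a gradient 1000x smaller: the move is still about lr.
a_small, _, _ = adam_step(param=0.0, grad=0.00625, m=0.0, v=0.0, t=1)
print(f"a after step 1 with a 1000x smaller gradient = {a_small:.6f}")
```

The second call shows why the very first Adam step has magnitude close to lr whatever the gradient scale: after bias correction, m̂ / √v̂ reduces to the sign of the gradient.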
33if step in (0, 100, 500, 1000, 2000):

Print only at five checkpoint steps so the convergence log is compact. Pythonic idiom: 'in tuple' is O(N) but N=5, so trivial.

EXECUTION STATE
step in (0, 100, 500, 1000, 2000) = Membership test. True for the five logged steps; False for the other 1996. Printing on every step would dump 2001 lines.
34print(f"{step:>4} | {a.item():>10.4f} {lam.item():>10.6f} {loss.item():>12.4e}")

Print one row of the convergence table. Each .item() unwraps a 0-dim or 1-element tensor into a Python scalar so the format spec applies to numbers, not tensors.

EXECUTION STATE
:>10.4f = Right-align in width 10, fixed-point with 4 decimals. 0.500000 → ' 0.5000'.
:>10.6f = Right-align in width 10, fixed-point with 6 decimals. 0.5 → ' 0.500000'.
:>12.4e = Right-align in width 12, scientific notation with 4 decimals. 6.225 → ' 6.2250e+00'. Used for loss because its magnitude spans many orders.
Output (step=0 row) = 0 | -0.5000 0.500000 6.2250e+00
37# ---------- Verify convergence to the closed form ----------

Section header for the post-training comparison block.

38final_lambda = torch.sigmoid(a).item()

Compute lambda one final time from the converged a, then unwrap to a Python float for printing and the abs() comparison on line 41.

EXECUTION STATE
torch.sigmoid(a) = Re-evaluate sigmoid at the final a ≈ -6.11. Returns a shape-(1,) tensor ≈ 0.002213.
.item() = Unwrap the single-element tensor to a Python float.
final_lambda = ≈ 0.002213. Compare against the closed form 0.001996.
39print(f"\nfinal lambda = {final_lambda:.6f}")

Print the converged lambda. Leading \n separates this verification block from the table above.

EXECUTION STATE
\n = Newline escape inside the f-string. Emits a blank line before this row.
:.6f = Fixed-point, 6 decimals.
Output = (blank line) final lambda = 0.002213
40print(f"closed-form lam* = {lambda_target:.6f}")

Print the analytic target so it sits directly under the converged value for visual comparison.

EXECUTION STATE
Output = closed-form lam* = 0.001996
41print(f"|final - lam*| = {abs(final_lambda - lambda_target):.2e}")

Final convergence diagnostic. abs() returns the magnitude of the difference; :.2e formats it in scientific notation with 2 decimals so the order of magnitude is obvious.

EXECUTION STATE
📚 abs(x) = Python builtin absolute value. For floats: returns |x|. For tensors: would call torch.abs(); here we're past .item() so we're working with a Python float.
final_lambda - lambda_target = 0.002213 - 0.001996 = 0.000217 (positive: lambda approaches the target from above and has not yet reached it).
:.2e = Scientific notation, 2 decimals. 0.000217 → 2.17e-04.
Final output (whole script) =
step |          a     lambda       gap_sq
--------------------------------------------------
   0 |    -0.5000   0.500000   6.2250e+00
 100 |    -5.7877   0.003056   2.8205e-05
 500 |    -5.8518   0.002867   1.9041e-05
1000 |    -5.9520   0.002594   8.9853e-06
2000 |    -6.1114   0.002213   1.1804e-06

final lambda     = 0.002213
closed-form lam* = 0.001996
|final - lam*|   = 2.17e-04
→ final reading = Adam gets to within 2.17e-04 of the closed form after 2001 iterations. The residual gap comes from slow convergence on the saturated sigmoid, not from a methodological error. It is also the operational answer to 'why GABA uses the closed form, not gradient descent': one division gives the answer to floating-point precision; 2001 forward+backward passes still leave a relative error of about 11%, i.e. roughly one significant digit.
1"""Gradient descent on a learnable lambda converges to the closed form."""
2
3import torch
4
5
6# ---------- Realistic FD002 numbers from section 12.3 ----------
7g_rul    = torch.tensor(5.0)
8g_health = torch.tensor(0.01)
9
10
11# Closed form target.
12lambda_target = (g_health / (g_rul + g_health)).item()
13
14
15# A learnable scalar 'a'. We parametrise lambda = sigmoid(a) so it is
16# always in [0, 1] without needing a constraint.
17a = torch.zeros(1, requires_grad=True)
18optimiser = torch.optim.Adam([a], lr=0.5)
19
20
21print(f"{'step':>4} | {'a':>10} {'lambda':>10} {'gap_sq':>12}")
22print("-" * 50)
23for step in range(2001):
24    lam      = torch.sigmoid(a)
25    c_rul    = lam * g_rul
26    c_health = (1 - lam) * g_health
27    loss     = (c_rul - c_health) ** 2
28
29    optimiser.zero_grad()
30    loss.backward()
31    optimiser.step()
32
33    if step in (0, 100, 500, 1000, 2000):
34        print(f"{step:>4} | {a.item():>10.4f} {lam.item():>10.6f} {loss.item():>12.4e}")
35
36
37# ---------- Verify convergence to the closed form ----------
38final_lambda = torch.sigmoid(a).item()
39print(f"\nfinal lambda     = {final_lambda:.6f}")
40print(f"closed-form lam* = {lambda_target:.6f}")
41print(f"|final - lam*|   = {abs(final_lambda - lambda_target):.2e}")
So why does the production GABA loss not learn $\lambda$ by gradient descent? Because the closed form is free. Computing $g_{\text{health}} / S$ is one division per step; running 2,001 Adam steps to approximate it is ~2,001 forwards + backwards. GABA uses the algebraic answer and spends the saved compute on the actual backbone.

Sensitivity: Why the Closed Form Is Stable

Differentiate the closed form analytically:

$$\frac{\partial \lambda^*_{\text{rul}}}{\partial g_{\text{rul}}} = -\frac{g_{\text{health}}}{(g_{\text{rul}} + g_{\text{health}})^2}, \qquad \frac{\partial \lambda^*_{\text{rul}}}{\partial g_{\text{health}}} = \frac{g_{\text{rul}}}{(g_{\text{rul}} + g_{\text{health}})^2}$$

At the realistic $(g_{\text{rul}}, g_{\text{health}}) = (5.0, 0.01)$ these evaluate to $-3.98 \times 10^{-4}$ and $+0.199$:

| Perturbation | What changes | Effect on λ* | Robustness verdict |
| --- | --- | --- | --- |
| +10% on the LARGE gradient (g_rul) | 5.0 → 5.5 | λ* shifts from 0.001996 to 0.001815 | Robust per unit of gradient: a +0.5 change moves λ* by only −1.8e-4 (−9% relative) |
| +10% on the SMALL gradient (g_health) | 0.01 → 0.011 | λ* shifts from 0.001996 to 0.002195 | Sensitive per unit of gradient: a +0.001 change moves λ* by +2.0e-4 (~+10% relative) |
| Symmetric noise (zero-mean) on both | Per-batch jitter | EMA(β=0.99) damps to <1% jitter | Stabilised by EMA |
| Sustained drift (training-time non-stationarity) | Slow change in either gradient | λ* tracks smoothly via EMA | Tracks correctly; no oscillation |
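The first two rows of the table, and the two partial derivatives they rest on, can be reproduced with a few lines of plain Python. This is a minimal sketch; the function name `lam_rul` is illustrative.

```python
def lam_rul(g_rul: float, g_health: float) -> float:
    """K=2 closed form for the RUL weight."""
    return g_health / (g_rul + g_health)


g_rul, g_health = 5.0, 0.01
S = g_rul + g_health

base = lam_rul(g_rul, g_health)
bumped_rul = lam_rul(1.10 * g_rul, g_health)        # +10% on the large gradient
bumped_health = lam_rul(g_rul, 1.10 * g_health)     # +10% on the small gradient

# Analytic partial derivatives of the closed form.
d_drul = -g_health / S ** 2
d_dhealth = g_rul / S ** 2

print(f"base lam*          = {base:.6f}")
print(f"+10% g_rul         = {bumped_rul:.6f}  ({(bumped_rul / base - 1) * 100:+.1f}% relative)")
print(f"+10% g_health      = {bumped_health:.6f}  ({(bumped_health / base - 1) * 100:+.1f}% relative)")
print(f"d lam*/d g_rul     = {d_drul:.3e}")
print(f"d lam*/d g_health  = {d_dhealth:.3e}")
print(f"derivative ratio   = {abs(d_dhealth / d_drul):.0f}x")
```

The relative shifts come out similar in size for both perturbations; the large asymmetry sits in the per-unit derivatives, whose ratio equals g_rul / g_health.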

The asymmetry is fundamental: the small-gradient task gets the partial derivative with the LARGE numerator, so the same absolute error moves $\lambda^*$ 500 times more when it lands on $g_{\text{health}}$ than on $g_{\text{rul}}$. That is also why GABA's EMA stabiliser ($\beta = 0.99$) is critical: it is on the small-gradient channel that per-batch noise would otherwise produce visible $\lambda^*$ oscillation.
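The stabilising effect of the EMA can be illustrated with a small simulation. Everything about the noise here is an assumption made for the sketch: independent Gaussian jitter with a 10% relative standard deviation on g_health, a fixed g_rul, and an EMA applied to the gradient norm itself before the closed form is evaluated. The source only states that an EMA with β = 0.99 is used.

```python
import random
import statistics


def lam_rul(g_rul: float, g_health: float) -> float:
    return g_health / (g_rul + g_health)


random.seed(0)          # fixed seed so the run is repeatable
beta = 0.99             # EMA coefficient quoted in the text
g_rul = 5.0             # held fixed: only the small-gradient channel is noisy
ema = 0.01              # start the EMA at the noise-free value

raw, smoothed = [], []
for step in range(50_000):
    g_health = random.gauss(0.01, 0.001)           # assumed 10% per-batch jitter
    ema = beta * ema + (1.0 - beta) * g_health     # exponential moving average
    if step >= 1_000:                              # skip the warm-up phase
        raw.append(lam_rul(g_rul, g_health))
        smoothed.append(lam_rul(g_rul, ema))

ratio = statistics.pstdev(smoothed) / statistics.pstdev(raw)
print(f"std of lam* without EMA = {statistics.pstdev(raw):.3e}")
print(f"std of lam* with EMA    = {statistics.pstdev(smoothed):.3e}")
print(f"jitter ratio            = {ratio:.3f}")
```

Correlated noise or a drifting mean would change the ratio, so the printed value should be read as an illustration of the mechanism rather than a guarantee.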

The Same Closed Form In Other Fields

The pattern $\lambda_i = a_j / (a_i + a_j)$ appears under different names whenever a two-party allocation equalises an effort-times-rate quantity:

| Field | Two-party split | Closed-form rule | What gets equalised |
| --- | --- | --- | --- |
| Networking (proportional fairness, Kelly 1998) | Two TCP flows, one shared link | throughput_i ∝ 1/RTT_i | rate × RTT (link occupancy) |
| Finance (risk parity) | Two-asset portfolio | weight_i ∝ 1/σ_i | weight × σ (risk contribution) |
| Game theory (Nash bargaining, K=2) | Bilateral split of joint surplus | share_i = (1 - share_j) by symmetry | log-utility increment |
| Climate-model ensembles (CMIP6 inverse variance) | Two-model mean | weight_i ∝ 1/var_i | weight × variance |
| Physics (parallel resistors) | Current through two parallel paths | I_i = R_j / (R_i + R_j) · I_total | voltage drop |
| Economics (Cournot duopoly with linear demand) | Two firms' quantities | q_i ∝ (a − c_j) | marginal revenue |
| RUL prediction (this book) | RUL + health on shared backbone | λ_i = g_j / (g_i + g_j) | λ × ‖g‖ (effective gradient contribution) |

The Kelly proportional-fairness derivation in 1998 is algebraically identical to GABA's K=2 closed form — the two papers were just published 27 years apart in different fields.
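Three rows of the table can be run through one shared function to see the equalisation directly. This is an illustrative sketch: `inverse_share` is a hypothetical helper, the resistor and volatility values are made up, and the risk-parity case ignores correlation between the two assets.

```python
def inverse_share(a_i: float, a_j: float) -> float:
    """Share given to party i: proportional to the OTHER party's magnitude."""
    return a_j / (a_i + a_j)


# Physics: current divider for two parallel resistors (made-up values).
R1, R2, I_total = 100.0, 300.0, 2.0
I1 = inverse_share(R1, R2) * I_total
I2 = inverse_share(R2, R1) * I_total
print(f"voltage drops   : {I1 * R1:.1f} V vs {I2 * R2:.1f} V")

# Finance: two-asset risk parity with made-up volatilities.
sigma1, sigma2 = 0.20, 0.05
w1 = inverse_share(sigma1, sigma2)
w2 = inverse_share(sigma2, sigma1)
print(f"risk contribs   : {w1 * sigma1:.4f} vs {w2 * sigma2:.4f}")

# GABA: the realistic gradient norms used throughout this section.
g_rul, g_health = 5.0, 0.01
lam_r = inverse_share(g_rul, g_health)
lam_h = inverse_share(g_health, g_rul)
print(f"gradient contribs: {lam_r * g_rul:.6f} vs {lam_h * g_health:.6f}")
```

In each case the two printed quantities match, which is the "what gets equalised" column of the table.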

Pitfalls When Using the Closed Form

Pitfall 1: Applying the K=2 form when K>2. For K>2 tasks the formula $\lambda_i = g_j / (g_i + g_j)$ is undefined, because there is no single ‘other’ gradient. The K-task generalisation in §17.1 collapses to this expression only when K=2. Using the K=2 form in K=3 code silently picks one task to ignore.
Pitfall 2: Computing $g_i$ on different parameter sets. Both gradient norms must be measured on the SAME shared parameters. A common bug: computing $g_{\text{rul}}$ on the entire model (including the RUL head) and $g_{\text{health}}$ only on the backbone. The closed form is then meaningless because the two norms are not comparable.
Pitfall 3: Dividing by zero when both gradients vanish. If $g_{\text{rul}} = g_{\text{health}} = 0$ (the model is at a stationary point of both losses), the closed form is $0/0$. The paper's implementation adds $+10^{-12}$ to the denominator and floors $\lambda_i \geq \lambda_{\min} = 0.05$; these stabilisers handle the corner case.
Why the closed form survives noisy gradients. The sensitivity $\partial \lambda^* / \partial g_{\text{health}} \approx 0.2$ means a 5% error in $g_{\text{health}}$ (an absolute error of 0.0005) shifts $\lambda^*$ by about $10^{-4}$, roughly 5% of its value. EMA smoothing with $\beta = 0.99$ further attenuates independent per-batch fluctuations by $\sqrt{(1 - \beta)/(1 + \beta)} \approx 0.07$. Combined, the algorithm tracks the closed form to within a few-percent jitter even on stochastic gradients, a property formally established in §19.
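A guarded version of the closed form, combining the two stabilisers named in Pitfall 3, might look like the sketch below. The function name `gaba_k2` is hypothetical, and the way the floor is combined with the simplex constraint (clamp λ_rul to [λ_min, 1 − λ_min], then set λ_health = 1 − λ_rul) is an assumption of this sketch, not a description of the paper's implementation.

```python
def gaba_k2(g_rul: float, g_health: float,
            eps: float = 1e-12, lam_min: float = 0.05) -> tuple[float, float]:
    """Stabilised K=2 closed form (sketch under the assumptions stated above)."""
    if g_rul < 0.0 or g_health < 0.0:
        raise ValueError("gradient norms must be non-negative")

    # eps keeps the division finite when both norms are zero (the 0/0 case).
    lam_rul = g_health / (g_rul + g_health + eps)

    # Assumed floor handling: clamp one weight, derive the other so they sum to 1.
    lam_rul = min(max(lam_rul, lam_min), 1.0 - lam_min)
    return lam_rul, 1.0 - lam_rul


print(gaba_k2(0.0, 0.0))     # degenerate 0/0 case stays finite
print(gaba_k2(1.0, 1.0))     # equal norms give an even split
print(gaba_k2(5.0, 0.01))    # realistic point: the floor is active here
```

Under these assumptions the floor is active at the realistic operating point, because the unclamped weight 0.001996 lies below λ_min = 0.05.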

Takeaway

  • The K=2 closed form is $\lambda^*_i = g_j / (g_i + g_j)$. One division. No iteration. No tuning.
  • Three independent derivations give the same answer. Lagrangian on equal-contribution; max-min fairness LP; minimum-variance objective. The formula is structural, not formulation-dependent.
  • Numerical optimisers and gradient descent both recover it. SciPy's bounded Brent finds it to 10 decimals; PyTorch Adam approaches it to within about $2 \times 10^{-4}$ after ~2,000 steps. The remainder is sigmoid saturation, not formulation error.
  • The closed form is sensitivity-asymmetric. $\partial \lambda^*/\partial g_{\text{health}}$ is 500× larger in magnitude than $\partial \lambda^*/\partial g_{\text{rul}}$ at the realistic operating point. This is exactly why GABA's EMA stabiliser smooths the small-gradient channel.
  • The same formula appears in networking, finance, physics, and game theory. Two-party inverse-proportional allocation is a cross-domain invariant; GABA is the gradient-space instance of a decades-old fairness law.
  • Production GABA uses the closed form, not gradient descent. It's 2,000× cheaper per step and exactly correct.