This is the refactored code (HLSL):
```hlsl
// Reads the texel at integer coordinate `pos` from mip 0 of `src` (Load: no filtering,
// no sampler) and returns its RGB components squared. The filter in this file operates
// in this squared color space and takes sqrt() of its result at the end.
float3 SampleTexSquared(Texture2D<float4> src, int2 pos) {
    const float3 rgb = src.Load(int3(pos.x, pos.y, 0)).rgb;
    return rgb * rgb;
}
// Contrast-adaptive 3x3 filter for one output pixel, refactored from a decompiled
// shader ("CAS9"). All math runs on squared color values (see SampleTexSquared) and
// every return path converts back with sqrt().
//
// centerPos : integer texel coordinate of the pixel being filtered.
// src       : source texture, read via SampleTexSquared (Load at mip 0, no filtering).
// cb0       : constant-buffer registers as seen in the decompile; only cb0[2].x/.y/.z are read.
// returns   : filtered RGB in the source's (un-squared) encoding.
//
// NOTE(review): the coefficient construction and the neighbor -> coefficient mapping
// were labelled an "example mapping" in the refactor; confirm both against the
// decompiled r1 / r11 assembly before relying on exact output.
float3 CAS9_Shader(int2 centerPos, Texture2D<float4> src, float4 cb0[3]) {
    // 3x3 neighborhood in squared color space:
    //   a b c      (row offset -1: tl, up, tr)
    //   d e f      (row offset  0: left, center, right)
    //   g h i      (row offset +1: bl, down, br)
    float3 a = SampleTexSquared(src, centerPos + int2(-1, -1));
    float3 b = SampleTexSquared(src, centerPos + int2( 0, -1));
    float3 c = SampleTexSquared(src, centerPos + int2( 1, -1));
    float3 d = SampleTexSquared(src, centerPos + int2(-1,  0));
    float3 e = SampleTexSquared(src, centerPos + int2( 0,  0));
    float3 f = SampleTexSquared(src, centerPos + int2( 1,  0));
    float3 g = SampleTexSquared(src, centerPos + int2(-1,  1));
    float3 h = SampleTexSquared(src, centerPos + int2( 0,  1));
    float3 i = SampleTexSquared(src, centerPos + int2( 1,  1));

    // Soft min/max of the green channel: the extreme of the "+" cross (b, d, e, f, h)
    // PLUS the extreme of all nine samples -- the "doubled accumulator" form of the
    // decompile. Both values are therefore on a 0..2 scale (for samples in 0..1), which
    // is what the `2.0f - maxG` term below relies on; with a plain 9-sample min/max that
    // term could never win the min() and would be dead.
    float minCross = min(min(min(b.g, d.g), min(e.g, f.g)), h.g);
    float maxCross = max(max(max(b.g, d.g), max(e.g, f.g)), h.g);
    float minG = minCross + min(minCross, min(min(a.g, c.g), min(g.g, i.g)));
    float maxG = maxCross + max(maxCross, max(max(a.g, c.g), max(g.g, i.g)));

    // All-zero green neighborhood: skip filtering (this also avoids 1/maxG below).
    // Return the center converted out of squared space, exactly like the main path,
    // instead of leaking the squared value.
    if (maxG == 0.0f) {
        return sqrt(saturate(e));
    }

    // Normalized conservative contrast metric, clamped to [0, 1].
    float amplify = saturate(min(minG, 2.0f - maxG) * (1.0f / maxG));

    // Quadratic shaper used by the decompile: -x^2 + 2x, floored at 0.5.
    float shaped = amplify * 2.0f;
    shaped = -amplify * amplify + shaped;
    shaped = max(0.5f, shaped);

    // Coefficient inputs driven by cb0[2] (mimics the r11 / r1 registers in the decompile).
    float2 scaleParams = shaped * float2(1.5f, 1.5f) + float2(1.0f, 2.0f);
    float4 aInput = (float4)scaleParams.x * cb0[2].xxxx + cb0[2].yyyy;
    float4 bInput = (float4)scaleParams.y * cb0[2].xxxx - cb0[2].yzyz;
    float4 A = saturate(aInput * aInput);
    float4 B = saturate(bInput * bInput);

    // Each coefficient is 1 + A * (scaleParams.x - scaleParams.y). From the offsets
    // above that difference is -1, so coeffSet0 == 1 - A and coeffSet1 == 1 - B up to
    // rounding: every weight lies in [0, 1].
    float4 coeffSet0 = float4(1, 1, 1, 1) + (-scaleParams.y * A + (A * scaleParams.x));
    float4 coeffSet1 = float4(1, 1, 1, 1) + (-scaleParams.y * B + (B * scaleParams.x));

    // Weighted sum of the nine neighbors, tracking the total weight for normalization.
    // NOTE(review): top-left (a) shares coeffSet0.x with up (b); the refactor's mapping
    // notes did not list `a`, so confirm this against the decompile.
    float3 accum = float3(0, 0, 0);
    float totalWeight = 0.0f;
    // top row: top-left, up, top-right
    accum += a * coeffSet0.xxx; totalWeight += coeffSet0.x;
    accum += b * coeffSet0.xxx; totalWeight += coeffSet0.x;
    accum += c * coeffSet0.yyy; totalWeight += coeffSet0.y;
    // middle row: left, center (uses the second coefficient set), right
    accum += d * coeffSet0.zzz; totalWeight += coeffSet0.z;
    accum += e * coeffSet1.xxx; totalWeight += coeffSet1.x;
    accum += f * coeffSet1.yyy; totalWeight += coeffSet1.y;
    // bottom row: bottom-left, down, bottom-right
    accum += g * coeffSet0.www; totalWeight += coeffSet0.w;
    accum += h * coeffSet1.zzz; totalWeight += coeffSet1.z;
    accum += i * coeffSet1.www; totalWeight += coeffSet1.w;

    // Normalize by the reciprocal of the weight sum (the decompile summed and used rcp);
    // the tiny epsilon keeps the reciprocal finite.
    // NOTE(review): if every weight is 0 (all A and B components saturate to 1), accum is
    // 0 and this pixel comes out black rather than passing the center through.
    float invSum = rcp(totalWeight + 1e-6f);
    float3 result = saturate(accum * invSum);

    // Undo the squared color space.
    return sqrt(result);
}