// Copyright 2010 Sundog Software, LLC. All rights reserved worldwide.

// The following parameters may be adjusted to tweak appearance:

// Maximum number of slices to render through the cloud volume.
// Increase this to reduce sampling artifacts. Decrease it
// to increase performance. This is both the ray-march loop count in PS
// and the divisor used to derive the per-sample step size.
#define MAX_SAMPLES 150

// The rate at which detail is added to the clouds with distance
// from the camera. This is needed to counteract sampling artifacts
// from high frequency detail near the camera.
// PS feeds 1 - exp(-DETAIL_FALLOFF * t) to getCloudDensity as the
// strength of the noise perturbation.
#define DETAIL_FALLOFF 2.0

// Number of samples to take for lighting each slice. Increase for
// more lighting accuracy, decrease for more performance. Only used
// if SMOOTH_LIGHTING is not defined
#define LIGHTING_DEPTH 3

// Uncomment to enable Reinhard tone mapping - tends to make things too
// transparent though. Has no effect while HDR is defined, because
// toneMap() is then never called.
//#define TONE_MAP

// Smooth lighting approximates the cloud's lighting by just looking
// at the sample's position within the cloud layer relative to the
// light source. It's faster and produces less detail in the lighting
// which actually looks better.
#define SMOOTH_LIGHTING

// Controls how light linearly falls off as a function of depth into
// the cloud layer (multiplies the smooth-lighting depth term in PS)
#define SMOOTH_LIGHTING_BRIGHTNESS 2.0

// The frequency of the noise used to add detail. Note that it is applied
// as a divisor of the noise lookup coordinate in getCloudDensity.
#define NOISE_FREQUENCY 2500

// Enable or disable gamma correction (only applied inside toneMap(),
// so it has no effect while HDR is defined)
#define GAMMA_CORRECTION

// Gamma value; toneMap() raises the color to the power 1 / GAMMA
#define GAMMA 2.2

// Exposure; toneMap() multiplies the RGB channels by this value
#define EXPOSURE 1.5

// Opacity boost; toneMap() multiplies the alpha channel by this value
#define OPACITY 1.0

// Opacity threshold at which we do early ray termination
#define OPACITY_THRESHOLD 0.9

// Density threshold below which a sample is skipped during the ray
// march (the sample contributes nothing; the fragment itself is kept)
#define EPSILON 0.001

// Whether to disable tone mapping and gamma correction. When defined,
// PS does not call toneMap() at all, so the exposure/opacity scaling and
// the final clamp to [0, 1] are skipped as well.
#define HDR

// Shader constants supplied by the application. The notes below describe
// only how the shaders in this file read each value.

// Applied to the incoming (eye-space) vertex position in VS.
uniform float4x4 projectionMatrix;
// 3D texture sampled for cloud density (.x channel).
uniform Texture3D gDiffuseMap;
// 3D noise texture used to perturb the density lookup.
uniform Texture3D gDiffuseMap2;
// Light color; .xyz is scaled by the constant term to form the scattering color.
uniform float4 lightColor;
// Only .w (the constant scattering term) is read in this file.
uniform float4 lightObjectDirAndConstTerm;
// .xyz is used as the light direction for the phase angle and for light stepping.
uniform float4 lightTexCoords;
// .xyz divides the noise lookup coordinate.
uniform float4 viewSampleDimensions;
// .xyz scales lightTexCoords to produce the per-step light increment.
uniform float4 lightSampleDimensions;
// .xyz scales the view direction when computing the size of one ray-march step.
uniform float4 voxelDimensions;
// Only .w (light extinction) is read, and only when SMOOTH_LIGHTING is undefined.
uniform float4 lightWorldDirAndExtinction;
// .xyz is the ray origin used for the view ray and the volume intersection.
uniform float4 cameraTexCoords;
// .xyz contributes to the ambient lighting term.
uniform float4 skyLightColor;
// .xyz is added to ray positions before sampling and subtracted again to
// obtain the position relative to the unit volume.
uniform float4 originTexCoords;
// .x > 0 enables edge fading; .z is the fade rate. Other components are unused here.
uniform float4 fadeFlag;
// .xyz is the fog color, .w the fog density.
uniform float4 fogColorAndDensity;
// .xyz is added to skyLightColor to form the ambient term.
uniform float4 multipleScatteringTerm;
// .xyz is added to the noise lookup coordinate.
uniform float4 noiseOffset;
// .x is multiplied by the step size to obtain the optical depth of one sample.
uniform float4 extinctionCoefficient;
// .xyz scales the noise perturbation applied to the density lookup.
uniform float4 jitter;
// Scales NOISE_FREQUENCY in the noise coordinate divisor.
uniform float4 unitScale;

// Interpolants passed from VS to PS.
struct SL_Vertex
{
    float4 pos      : SV_Position;  // projected vertex position
    float4 color    : COLOR0;       // vertex color, multiplied into the final fragment
    float3 tex      : TEXCOORD0;    // texture coordinate of the vertex, passed through by VS
    float3 eyeDepth : TEXCOORD1;    // x = -(eye-space z) of the vertex; y and z are zero
};

// Sampler declarations. The DX9 build uses effect-style sampler3D objects
// bound to their textures; every other build declares SamplerState objects
// that are paired with a texture at the sampling call site.
#ifdef DX9
// Noise lookup: bilinear min/mag filtering, repeating on all three axes.
sampler3D gNoiseSampler = sampler_state
{
    Texture = <gDiffuseMap2>;
    MinFilter = LINEAR;
    MagFilter = LINEAR;
    AddressU = WRAP;
    AddressV = WRAP;
    AddressW = WRAP;
};

// Cloud density lookup: repeats along U and W but clamps along V.
// NOTE(review): presumably V is the vertical axis of the cloud layer - confirm.
sampler3D gCloudSampler = sampler_state
{
    Texture = <gDiffuseMap>;
    MinFilter = LINEAR;
    MagFilter = LINEAR;
    AddressU = WRAP;
    AddressV = CLAMP;
    AddressW = WRAP;
};
#else
// Noise lookup: trilinear filtering, repeating on all three axes.
SamplerState gNoiseSampler
{
    Filter = MIN_MAG_MIP_LINEAR;
    AddressU = WRAP;
    AddressV = WRAP;
    AddressW = WRAP;
};

// Cloud density lookup: repeats along U and W but clamps along V.
SamplerState gCloudSampler
{
    Filter = MIN_MAG_MIP_LINEAR;
    AddressU = WRAP;
    AddressV = CLAMP;
    AddressW = WRAP;
};
#endif

// Vertex shader: projects the vertex and forwards the values PS needs.
//   position - vertex position, already in eye coordinates
//   color    - vertex color, passed through unchanged
//   texCoord - texture coordinate, passed through unchanged
//   oVert    - interpolants handed to the pixel shader
void VS(
    float4 position : POSITION,
    float4 color : COLOR,
    float3 texCoord : TEXCOORD0,
    out SL_Vertex oVert)
{
    // The position arrives in eye space, so only the projection is left to apply.
    oVert.pos = mul(projectionMatrix, position);

    // Straight pass-through attributes.
    oVert.color = color;
    oVert.tex = texCoord;

    // Only x carries data: the negated eye-space z of the vertex.
    oVert.eyeDepth = float3(-position.z, 0, 0);
}

// Intersects a ray with the unit cube [0,1]^3 using the slab method.
//   pos   - ray origin
//   dir   - ray direction
//   tNear - out: ray parameter where the ray enters the cube, clamped to >= 0
//   tFar  - out: ray parameter where the ray leaves the cube, clamped to >= 0
void getVolumeIntersection(in float3 pos, in float3 dir, out float tNear, out float tFar)
{
    // Ray parameters at the 0-plane and the 1-plane of every axis.
    float3 rcpDir = 1.0 / dir;
    float3 tAtZero = -pos * rcpDir;
    float3 tAtOne = (1.0 - pos) * rcpDir;

    // Per axis: where the ray enters and where it leaves that slab.
    float3 tEnter = min(tAtOne, tAtZero);
    float3 tExit = max(tAtOne, tAtZero);

    // The cube is entered at the latest slab entry and left at the earliest slab exit.
    float2 enterPair = max(tEnter.xx, tEnter.yz);
    float2 exitPair = min(tExit.xx, tExit.yz);

    // Intersections behind the ray origin are clamped to the origin itself.
    tNear = max(0.0, max(enterPair.x, enterPair.y));
    tFar = max(0.0, min(exitPair.x, exitPair.y));
}

// Rayleigh phase function.
//   cosa - cosine of the angle between the view and light directions
// Returns 0.75 * (1 + cosa^2).
float phaseFunction(in float cosa)
{
    float cosSquared = cosa * cosa;
    return 0.75 * (1.0 + cosSquared);
}

// Maps a color into displayable range, in place.
//   c - in/out: color to process; all four channels are modified
// Steps: exposure/opacity scaling, optional Reinhard mapping (TONE_MAP),
// optional gamma correction (GAMMA_CORRECTION), then a clamp to [0, 1].
void toneMap(inout float4 c)
{
    // RGB is scaled by the exposure, alpha by the opacity boost.
    float4 channelScale = float4(EXPOSURE, EXPOSURE, EXPOSURE, OPACITY);
    c *= channelScale;

#ifdef TONE_MAP
    // Simplified Reinhard tone mapping operator
    c = (c / (c + 1.0));
#endif

#ifdef GAMMA_CORRECTION
    // The same exponent is applied to every channel, alpha included.
    const float invGamma = 1.0 / GAMMA;
    c = pow(c, float4(invGamma, invGamma, invGamma, invGamma));
#endif

    c = clamp(c, 0.0, 1.0);
}

// Computes how strongly a sample is faded out towards the edge of the volume.
//   posWithinVolume - sample position relative to the volume; only .xz is used
// Returns 0 when fading is disabled (fadeFlag.x <= 0); otherwise a value that
// grows from 0 at the xz center (0.5, 0.5) with the cube of the distance from it.
float computeFade(in float3 posWithinVolume)
{
    // Fading switched off: nothing is faded.
    if (!(fadeFlag.x > 0))
    {
        return 0.0;
    }

    // Distance from the xz center, scaled so that the middle of an edge is at 1.
    float2 offsetFromCenter = posWithinVolume.xz - float2(0.5, 0.5);
    float radial = length(offsetFromCenter) * 2.0;

    // Exponential falloff driven by the fade rate in fadeFlag.z.
    return 1.0 - exp(-1.0 * fadeFlag.z * radial * radial * radial);
}

// Blends a color towards the fog color based on depth, in place.
//   c        - in/out: color to fog; only .xyz is modified
//   eyeDepth - depth value multiplied by the fog density
void computeFog(inout float4 c, in float eyeDepth)
{
    // The exponent is clamped to [0, 1], so the fog factor never drops below exp(-1).
    float fogAmount = clamp(fogColorAndDensity.w * eyeDepth, 0.0, 1.0);
    float visibility = clamp(exp(-abs(fogAmount)), 0.0, 1.0);

    // visibility == 1 keeps the original color, 0 would be pure fog color.
    c.xyz = lerp(fogColorAndDensity.xyz, c.xyz, visibility);
}

// Samples the cloud density at a position, perturbing the lookup with two
// octaves of noise to add detail.
//   texCoord - position at which to sample the density texture
//   t        - strength of the noise perturbation (0 disables it)
// Returns the .x channel of the density texture at the perturbed position.
//
// All vector widths are spelled out explicitly: the noise divisor uses
// unitScale.xyz and the Texture3D lookups receive float3 coordinates, rather
// than relying on implicit float4 -> float3 truncation.
float getCloudDensity(in float3 texCoord, in float t)
{
    float3 noiseCoord = ((texCoord + noiseOffset.xyz) / viewSampleDimensions.xyz) / (NOISE_FREQUENCY * unitScale.xyz);

#ifdef DX9
    // tex3Dlod reads the mip level from .w. The levels below (2, 4 and, for
    // the density lookup, 1) are the values the original float4 arithmetic
    // produced and are kept to preserve behavior.
    // NOTE(review): the non-DX9 path samples level 0 everywhere - confirm
    // whether these non-zero levels are intentional.
    float3 noiseLow = tex3Dlod(gNoiseSampler, float4(2.0 * noiseCoord, 2.0)).xyz;
    float3 noiseHigh = tex3Dlod(gNoiseSampler, float4(4.0 * noiseCoord, 4.0)).xyz;
#else
    float3 noiseLow = gDiffuseMap2.SampleLevel(gNoiseSampler, 2.0 * noiseCoord, 0).xyz;
    float3 noiseHigh = gDiffuseMap2.SampleLevel(gNoiseSampler, 4.0 * noiseCoord, 0).xyz;
#endif

    // Re-center each octave around zero; the second octave has half the amplitude.
    float3 perturb = float3(0, 0, 0);
    perturb += 1.0 * noiseLow - 0.5;
    perturb += 0.5 * noiseHigh - 0.25;

    float3 sampleCoord = texCoord + perturb * jitter.xyz * t;

#ifdef DX9
    return tex3Dlod(gCloudSampler, float4(sampleCoord, 1.0)).x;
#else
    return gDiffuseMap.SampleLevel(gCloudSampler, sampleCoord, 0).x;
#endif
}

// Pixel shader: marches a ray from the camera through the cloud volume,
// lighting and compositing up to MAX_SAMPLES density samples front to back.
//   inVert - interpolants from VS (texture coordinate on the volume's
//            bounding geometry, eye depth and vertex color)
// Returns the accumulated color/opacity multiplied by the vertex color.
// Tone mapping is applied only when HDR is not defined.
float4 PS(SL_Vertex inVert) : SV_TARGET
{
    float3 texCoord = inVert.tex;

    float3 view = texCoord - cameraTexCoords.xyz;
    float3 viewDir = normalize(view);

    // Find the intersections of the volume with the viewing ray
    float tminView, tmaxView;
    getVolumeIntersection(cameraTexCoords.xyz, viewDir, tminView, tmaxView);

    // Size of one marching step, scaled by the voxel dimensions, converted
    // into the opacity contributed by a single sample.
    float3 vv = voxelDimensions.xyz;
    vv = vv * viewDir;
    float sampleSize = length((tmaxView - tminView) * vv);
    sampleSize /= MAX_SAMPLES;
    float opticalDepth = extinctionCoefficient.x * sampleSize;
    float correctedOpacity = 1.0 - exp(-opticalDepth);

    float constTerm = lightObjectDirAndConstTerm.w;

    float4 fragColor = float4(0, 0, 0, 0);

    // Ray parameter advance per sample
    float viewInc = (tmaxView - tminView) / MAX_SAMPLES;

    float3 lightSampleInc = lightSampleDimensions.xyz * lightTexCoords.xyz;
#ifdef SMOOTH_LIGHTING
    float3 lightDir = normalize(lightSampleInc);
#else
    float extinctionLight = lightWorldDirAndExtinction.w;
#endif

    float3 ambientTerm = skyLightColor.xyz + multipleScatteringTerm.xyz;

    float t = tminView;

    float cosa = dot(normalize( vv ), lightTexCoords.xyz);
    float phase = phaseFunction(cosa);

    // Shared by both lighting paths; constant for the whole ray.
    float3 scattering = lightColor.xyz * constTerm;

    // eyeDepth is a depth of this fragment on cloud box face.
    // these are front faces when camera looks from outside (tminView > 0) and
    // back faces when camera is inside the box (tminView == 0).
    float depthFactor = inVert.eyeDepth.x / ( tminView > 0. ? tminView : tmaxView );

    // The ray origin in texture space does not change between samples.
    float3 rayOrigin = cameraTexCoords.xyz + originTexCoords.xyz;

    for (int sampleNum = 0; sampleNum < MAX_SAMPLES; sampleNum++)
    {
        float3 sampleTexCoords = rayOrigin + viewDir * t;

        // Note: t is advanced before it is used for the detail falloff and
        // the fog depth below.
        t += viewInc;

        float texel = getCloudDensity(sampleTexCoords, 1.0 - exp(-DETAIL_FALLOFF * t));

        // Skip samples with (almost) no cloud in them
        if (texel < EPSILON) continue;

        float fade = computeFade(sampleTexCoords.xyz - originTexCoords.xyz);


        // apply lighting
#ifndef SMOOTH_LIGHTING
        // Start LIGHTING_DEPTH steps away along the light increment and walk
        // back towards the sample, attenuating the light by the density found.
        float4 accumulatedColor = lightColor;
        float4 samplePos;
        samplePos.xyz = sampleTexCoords.xyz + lightSampleInc * LIGHTING_DEPTH;
        samplePos.w = 1.0;

        for (int i = 0; i < LIGHTING_DEPTH; i++)
        {
        #ifdef DX9
            float lightSample = tex3Dlod(gCloudSampler, samplePos).x;
        #else
            float lightSample = gDiffuseMap.SampleLevel(gCloudSampler, samplePos.xyz, 0).x;
        #endif
            if (lightSample != 0)
            {
                float4 srcColor;
                srcColor.xyz = accumulatedColor.xyz * scattering;
                srcColor.w = extinctionLight;
                srcColor *= lightSample;

                accumulatedColor = srcColor + (1.0 - srcColor.w) * accumulatedColor;
            }

            samplePos.xyz -= lightSampleInc;
        }

        float4 fragSample;
        fragSample.xyz = accumulatedColor.xyz * phase + ambientTerm;
#else
        // Approximate the lighting from how far the sample is from the
        // volume boundary along the light direction.
        float tNear, tFar;
        getVolumeIntersection(sampleTexCoords - originTexCoords.xyz, lightDir, tNear, tFar);
        float cdepth = (1.0 - tFar) * SMOOTH_LIGHTING_BRIGHTNESS;
        float4 fragSample;
        fragSample.xyz = scattering * cdepth * phase + ambientTerm;
#endif

        fragSample.w = correctedOpacity;

        // apply fog
        float depth = t * depthFactor;
        computeFog(fragSample, depth);

        // apply fading
        fragSample *= (1.0 - fade);

        // apply texture
        fragSample *= texel;

        // Under operator for compositing:
        fragColor = fragColor + (1.0 - fragColor.w) * fragSample;

        // Early ray termination!
        if (fragColor.w > OPACITY_THRESHOLD)
        {
            break;
        }
    }

#ifndef HDR
    toneMap(fragColor);
#endif

    return fragColor * inVert.color;
}

// Direct3D 11 technique: VS/PS compiled for shader model 5.0, no geometry shader.
#ifdef DX11
technique11 ColorTech
{
    pass P0
    {
        SetVertexShader( CompileShader( vs_5_0, VS() ) );
        SetGeometryShader( NULL );
        SetPixelShader( CompileShader( ps_5_0, PS() ) );
    }
}
#endif

// Direct3D 10 technique: VS/PS compiled for shader model 4.0, no geometry shader.
#ifdef DX10
technique10 ColorTech
{
    pass P0
    {
        SetVertexShader( CompileShader( vs_4_0, VS() ) );
        SetGeometryShader( NULL );
        SetPixelShader( CompileShader( ps_4_0, PS() ) );
    }
}
#endif

// Direct3D 10 technique targeting the 4_0_level_9_1 shader profiles,
// no geometry shader.
#ifdef DX10LEVEL9
technique10 ColorTech
{
    pass P0
    {
        SetVertexShader( CompileShader( vs_4_0_level_9_1, VS() ) );
        SetGeometryShader( NULL );
        SetPixelShader( CompileShader( ps_4_0_level_9_1, PS() ) );
    }
}
#endif

// Direct3D 9 technique (unnamed): VS/PS compiled for shader model 3.0.
#ifdef DX9
technique
{
    pass P0
    {
        SetVertexShader( CompileShader( vs_3_0, VS() ) );
        SetPixelShader( CompileShader( ps_3_0, PS() ) );
    }
}
#endif
