[Mesa-dev] [PATCH 3/4] nir: add all combinations of conversions with rounding and saturation

Jason Ekstrand jason at jlekstrand.net
Sat Apr 28 15:29:52 UTC 2018


On Sat, Apr 28, 2018 at 4:14 AM, Karol Herbst <kherbst at redhat.com> wrote:

> OpenCL has explicit casts where one can specify the rounding mode and put a
> sat modifier:
>
> https://www.khronos.org/registry/OpenCL/sdk/2.1/docs/man/xhtml/convert_T.html
>
> _sat is valid for all conversions to an integer type and rounding modes are
> valid for all conversions involving floats.
>
> Although the FPRoundingMode decoration is allowed without any capability
> restrictions, in GLSL it can only be used together with fp16. Additionally,
> it can be used for conversions to/from floating-point types in OpenCL.
>
> The SaturatedConversion modifier, OpSatConvertUToS and OpSatConvertSToU are
> only supported for Kernels, so current drivers are safe.
>
> Signed-off-by: Karol Herbst <kherbst at redhat.com>
> ---
>  src/compiler/glsl/glsl_to_nir.cpp |   2 +-
>  src/compiler/nir/nir.h            |   2 +-
>  src/compiler/nir/nir_opcodes.py   |  28 +++++-----
>  src/compiler/nir/nir_opcodes_c.py |  26 +++++----
>  src/compiler/spirv/spirv_to_nir.c |   4 +-
>  src/compiler/spirv/vtn_alu.c      | 108 ++++++++++++++++++++++++--------------
>  src/compiler/spirv/vtn_glsl450.c  |   2 +-
>  src/compiler/spirv/vtn_private.h  |   2 +-
>  8 files changed, 107 insertions(+), 67 deletions(-)
>
> diff --git a/src/compiler/glsl/glsl_to_nir.cpp
> b/src/compiler/glsl/glsl_to_nir.cpp
> index 8e5e9c34912..fcb6ef27e47 100644
> --- a/src/compiler/glsl/glsl_to_nir.cpp
> +++ b/src/compiler/glsl/glsl_to_nir.cpp
> @@ -1589,7 +1589,7 @@ nir_visitor::visit(ir_expression *ir)
>        nir_alu_type src_type = nir_get_nir_type_for_glsl_
> base_type(types[0]);
>        nir_alu_type dst_type = nir_get_nir_type_for_glsl_
> base_type(out_type);
>        result = nir_build_alu(&b, nir_type_conversion_op(src_type,
> dst_type,
> -                                 nir_rounding_mode_undef),
> +                                 nir_rounding_mode_undef, false),
>                                   srcs[0], NULL, NULL, NULL);
>        /* b2i and b2f don't have fixed bit-size versions so the builder
> will
>         * just assume 32 and we have to fix it up here.
> diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
> index f3326e6df94..f32e5bd8bb2 100644
> --- a/src/compiler/nir/nir.h
> +++ b/src/compiler/nir/nir.h
> @@ -784,7 +784,7 @@ nir_get_nir_type_for_glsl_type(const struct glsl_type
> *type)
>  }
>
>  nir_op nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
> -                              nir_rounding_mode rnd);
> +                              nir_rounding_mode rnd, bool saturation);
>
>  typedef enum {
>     NIR_OP_IS_COMMUTATIVE = (1 << 0),
> diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_
> opcodes.py
> index f4cd175bc6a..9c51f77bf1b 100644
> --- a/src/compiler/nir/nir_opcodes.py
> +++ b/src/compiler/nir/nir_opcodes.py
> @@ -168,26 +168,28 @@ unop("flog2", tfloat, "log2f(src0)")
>
>  # Generate all of the numeric conversion opcodes
>  for src_t in [tint, tuint, tfloat]:
> -   if src_t in (tint, tuint):
> -      dst_types = [tfloat, src_t]
> -   elif src_t == tfloat:
> -      dst_types = [tint, tuint, tfloat]
> -
> -   for dst_t in dst_types:
> +   for dst_t in [tint, tuint, tfloat]:
>        if dst_t == tfloat:
>           bit_sizes = [16, 32, 64]
> +         sat_modes = ['']
>        else:
>           bit_sizes = [8, 16, 32, 64]
> +         if src_t != tfloat and dst_t != src_t:
> +            sat_modes = ['_sat']
> +         else:
> +            sat_modes = ['_sat', '']
>        for bit_size in bit_sizes:
> -          if dst_t == tfloat and src_t == tfloat:
> -              rnd_modes = ['_rtne', '_rtz', '']
> -              for rnd_mode in rnd_modes:
> +          for sat_mode in sat_modes:
> +              if src_t == tfloat or dst_t == tfloat:
> +                  for rnd_mode in ['_rtne', '_rtz', '_ru', '_rd', '']:
> +                      unop_convert("{0}2{1}{2}{3}{4}".format(src_t[0],
> dst_t[0],
> +                                                             bit_size,
> rnd_mode,
> +                                                             sat_mode),
> +                                   dst_t + str(bit_size), src_t, "src0")
> +              else:
>                    unop_convert("{0}2{1}{2}{3}".format(src_t[0], dst_t[0],
> -                                                       bit_size,
> rnd_mode),
> +                                                      bit_size, sat_mode),
>                                 dst_t + str(bit_size), src_t, "src0")
> -          else:
> -              unop_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
> bit_size),
> -                           dst_t + str(bit_size), src_t, "src0")
>

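As I mentioned on IRC, we need proper constant folding for these opcodes.
As written, every generated conversion opcode constant-folds with the same
plain "src0" expression, so both the rounding mode and the _sat flag are
silently ignored when folding.  Getting rounding modes on f32->f16 wrong
isn't good, and I probably shouldn't have let that through.  Let's not make
the problem worse.  Not correctly handling _sat is especially bad, since a
folded saturating conversion can give a different result for out-of-range
inputs than the same conversion evaluated at runtime.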


>
>  # We'll hand-code the to/from bool conversion opcodes.  Because bool
> doesn't
>  # have multiple bit-sizes, we can always infer the size from the other
> type.
> diff --git a/src/compiler/nir/nir_opcodes_c.py b/src/compiler/nir/nir_
> opcodes_c.py
> index 19079f86e7b..9b8642f0cc1 100644
> --- a/src/compiler/nir/nir_opcodes_c.py
> +++ b/src/compiler/nir/nir_opcodes_c.py
> @@ -30,7 +30,8 @@ template = Template("""
>  #include "nir.h"
>
>  nir_op
> -nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
> nir_rounding_mode rnd)
> +nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
> nir_rounding_mode rnd,
> +                       bool saturate)
>  {
>     nir_alu_type src_base = (nir_alu_type) nir_alu_type_get_base_type(
> src);
>     nir_alu_type dst_base = (nir_alu_type) nir_alu_type_get_base_type(
> dst);
> @@ -41,7 +42,8 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type
> dst, nir_rounding_mode rnd
>        return nir_op_fmov;
>     } else if ((src_base == nir_type_int || src_base == nir_type_uint) &&
>                (dst_base == nir_type_int || dst_base == nir_type_uint) &&
> -              src_bit_size == dst_bit_size) {
> +              src_bit_size == dst_bit_size &&
> +              (src_base == dst_base || !saturate)) {
>        /* Integer <-> integer conversions with the same bit-size on both
>         * ends are just no-op moves.
>         */
> @@ -54,12 +56,9 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type
> dst, nir_rounding_mode rnd
>           switch (dst_base) {
>  %           for dst_t in ['int', 'uint', 'float']:
>              case nir_type_${dst_t}:
> +<%             orig_dst_t = dst_t %>
>  %              if src_t in ['int', 'uint'] and dst_t in ['int', 'uint']:
> -%                 if dst_t == 'int':
> -<%                   continue %>
> -%                 else:
> -<%                   dst_t = src_t %>
> -%                 endif
> +<%                dst_t = src_t %>
>  %              endif
>                 switch (dst_bit_size) {
>  %                 if dst_t == 'float':
> @@ -69,18 +68,25 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type
> dst, nir_rounding_mode rnd
>  %                 endif
>  %                 for dst_bits in bit_sizes:
>                    case ${dst_bits}:
> -%                    if src_t == 'float' and dst_t == 'float':
> +%                    if src_t == 'float' or dst_t == 'float':
>                       switch(rnd) {
> -%                       for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'),
> ('undef', '')]:
> +%                       for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'),
> ('ru', '_ru'), ('rd', '_rd'), ('undef', '')]:
>                          case nir_rounding_mode_${rnd_t[0]}:
> +%                          if dst_t != 'float':
> +                           if (saturate)
> +                              return ${'nir_op_{0}2{1}{2}{3}_sat'.format(src_t[0],
> dst_t[0],
> +
>  dst_bits, rnd_t[1])};
> +%                          endif
>                             return ${'nir_op_{0}2{1}{2}{3}'.format(src_t[0],
> dst_t[0],
>
> dst_bits, rnd_t[1])};
>  %                       endfor
>                          default:
> -                           unreachable("Invalid 16-bit nir rounding
> mode");
> +                           unreachable("Invalid float nir rounding mode");
>                       }
>  %                    else:
>                       assert(rnd == nir_rounding_mode_undef);
> +                     if (saturate)
> +                        return ${'nir_op_{0}2{1}{2}_sat'.format(src_t[0],
> orig_dst_t[0], dst_bits)};
>                       return ${'nir_op_{0}2{1}{2}'.format(src_t[0],
> dst_t[0], dst_bits)};
>  %                    endif
>  %                 endfor
> diff --git a/src/compiler/spirv/spirv_to_nir.c
> b/src/compiler/spirv/spirv_to_nir.c
> index 2a835f047e4..6f1a1871b38 100644
> --- a/src/compiler/spirv/spirv_to_nir.c
> +++ b/src/compiler/spirv/spirv_to_nir.c
> @@ -1726,7 +1726,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp
> opcode,
>              bit_size = glsl_get_bit_size(val->type->type);
>           };
>
> -         nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> +         nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode,
> &swap,
>
> nir_alu_type_get_type_size(src_alu_type),
>
> nir_alu_type_get_type_size(dst_alu_type));
>           nir_const_value src[4];
> @@ -3839,6 +3839,8 @@ vtn_handle_body_instruction(struct vtn_builder *b,
> SpvOp opcode,
>     case SpvOpUConvert:
>     case SpvOpSConvert:
>     case SpvOpFConvert:
> +   case SpvOpSatConvertUToS:
> +   case SpvOpSatConvertSToU:
>     case SpvOpQuantizeToF16:
>     case SpvOpConvertPtrToU:
>     case SpvOpConvertUToPtr:
> diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
> index 3134849ba90..b96f7d688fb 100644
> --- a/src/compiler/spirv/vtn_alu.c
> +++ b/src/compiler/spirv/vtn_alu.c
> @@ -273,8 +273,46 @@ vtn_handle_bitcast(struct vtn_builder *b, struct
> vtn_ssa_value *dest,
>     dest->def = nir_vec(&b->nb, dest_chan, dest_components);
>  }
>
> +static void
> +handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int
> member,
> +                     const struct vtn_decoration *dec, void
> *_out_rounding_mode)
> +{
> +   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
> +   assert(dec->scope == VTN_DEC_DECORATION);
> +   if (dec->decoration != SpvDecorationFPRoundingMode)
> +      return;
> +   switch (dec->literals[0]) {
> +   case SpvFPRoundingModeRTE:
> +      *out_rounding_mode = nir_rounding_mode_rtne;
> +      break;
> +   case SpvFPRoundingModeRTZ:
> +      *out_rounding_mode = nir_rounding_mode_rtz;
> +      break;
> +   case SpvFPRoundingModeRTP:
> +      *out_rounding_mode = nir_rounding_mode_ru;
> +      break;
> +   case SpvFPRoundingModeRTN:
> +      *out_rounding_mode = nir_rounding_mode_rd;
> +      break;
> +   default:
> +      unreachable("Not supported rounding mode");
> +      break;
> +   }
> +}
> +
> +static void
> +handle_saturation(struct vtn_builder *b, struct vtn_value *val, int
> member,
> +                  const struct vtn_decoration *dec, void *_out_saturation)
> +{
> +   bool *out_saturation = _out_saturation;
> +   assert(dec->scope == VTN_DEC_DECORATION);
> +   if (dec->decoration != SpvDecorationSaturatedConversion)
> +      return;
> +   *out_saturation = true;
> +}
> +
>  nir_op
> -vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
> +vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct vtn_value
> *val,
>                                  SpvOp opcode, bool *swap,
>                                  unsigned src_bit_size, unsigned
> dst_bit_size)
>  {
> @@ -356,42 +394,67 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder
> *b,
>     case SpvOpConvertSToF:
>     case SpvOpConvertUToF:
>     case SpvOpSConvert:
> -   case SpvOpFConvert: {
> +   case SpvOpFConvert:
> +   case SpvOpSatConvertUToS:
> +   case SpvOpSatConvertSToU: {
>        nir_alu_type src_type;
>        nir_alu_type dst_type;
>
> +      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
> +      bool saturation = false;
> +
>        switch (opcode) {
>        case SpvOpConvertFToS:
>           src_type = nir_type_float;
>           dst_type = nir_type_int;
> +         vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> +         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
>           break;
>        case SpvOpConvertFToU:
>           src_type = nir_type_float;
>           dst_type = nir_type_uint;
> +         vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> +         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
>           break;
>        case SpvOpFConvert:
>           src_type = dst_type = nir_type_float;
> +         vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
>           break;
>        case SpvOpConvertSToF:
>           src_type = nir_type_int;
>           dst_type = nir_type_float;
> +         vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
>           break;
>        case SpvOpSConvert:
>           src_type = dst_type = nir_type_int;
> +         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
>           break;
>        case SpvOpConvertUToF:
>           src_type = nir_type_uint;
>           dst_type = nir_type_float;
> +         vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
>           break;
>        case SpvOpUConvert:
>           src_type = dst_type = nir_type_uint;
> +         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
> +         break;
> +      case SpvOpSatConvertUToS:
> +         src_type = nir_type_uint;
> +         dst_type = nir_type_int;
> +         saturation = true;
> +         break;
> +      case SpvOpSatConvertSToU:
> +         src_type = nir_type_int;
> +         dst_type = nir_type_uint;
> +         saturation = true;
>           break;
>        default:
>           unreachable("Invalid opcode");
>        }
>        src_type |= src_bit_size;
>        dst_type |= dst_bit_size;
> -      return nir_type_conversion_op(src_type, dst_type,
> nir_rounding_mode_undef);
> +
> +      return nir_type_conversion_op(src_type, dst_type, rounding_mode,
> saturation);
>     }
>     /* Derivatives: */
>     case SpvOpDPdx:         return nir_op_fddx;
> @@ -417,27 +480,6 @@ handle_no_contraction(struct vtn_builder *b, struct
> vtn_value *val, int member,
>     b->nb.exact = true;
>  }
>
> -static void
> -handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int
> member,
> -                     const struct vtn_decoration *dec, void
> *_out_rounding_mode)
> -{
> -   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
> -   assert(dec->scope == VTN_DEC_DECORATION);
> -   if (dec->decoration != SpvDecorationFPRoundingMode)
> -      return;
> -   switch (dec->literals[0]) {
> -   case SpvFPRoundingModeRTE:
> -      *out_rounding_mode = nir_rounding_mode_rtne;
> -      break;
> -   case SpvFPRoundingModeRTZ:
> -      *out_rounding_mode = nir_rounding_mode_rtz;
> -      break;
> -   default:
> -      unreachable("Not supported rounding mode");
> -      break;
> -   }
> -}
> -
>  void
>  vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
>                 const uint32_t *w, unsigned count)
> @@ -579,7 +621,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
>        bool swap;
>        unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
>        unsigned dst_bit_size = glsl_get_bit_size(type);
> -      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> +      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
>                                                    src_bit_size,
> dst_bit_size);
>
>        if (swap) {
> @@ -605,7 +647,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
>        bool swap;
>        unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
>        unsigned dst_bit_size = glsl_get_bit_size(type);
> -      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> +      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
>                                                    src_bit_size,
> dst_bit_size);
>
>        assert(!swap);
> @@ -623,23 +665,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
>        vtn_handle_bitcast(b, val->ssa, src[0]);
>        break;
>
> -   case SpvOpFConvert: {
> -      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_
> type(vtn_src[0]->type);
> -      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
> -      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
> -
> -      vtn_foreach_decoration(b, val, handle_rounding_mode,
> &rounding_mode);
> -      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type,
> rounding_mode);
> -
> -      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL,
> NULL);
> -      break;
> -   }
> -
>     default: {
>        bool swap;
>        unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
>        unsigned dst_bit_size = glsl_get_bit_size(type);
> -      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
> +      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
>                                                    src_bit_size,
> dst_bit_size);
>
>        if (swap) {
> diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_
> glsl450.c
> index 6fa759b1bba..284371446b5 100644
> --- a/src/compiler/spirv/vtn_glsl450.c
> +++ b/src/compiler/spirv/vtn_glsl450.c
> @@ -659,7 +659,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum
> GLSLstd450 entrypoint,
>           nir_op conversion_op =
>              nir_type_conversion_op(nir_type_float | eta->bit_size,
>                                     nir_type_float | I->bit_size,
> -                                   nir_rounding_mode_undef);
> +                                   nir_rounding_mode_undef, false);
>           eta = nir_build_alu(nb, conversion_op, eta, NULL, NULL, NULL);
>        }
>        /* k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I)) */
> diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_
> private.h
> index b501bbf9b4a..0895c865fbb 100644
> --- a/src/compiler/spirv/vtn_private.h
> +++ b/src/compiler/spirv/vtn_private.h
> @@ -708,7 +708,7 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct
> vtn_builder *,
>  void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value
> *value,
>                                  vtn_execution_mode_foreach_cb cb, void
> *data);
>
> -nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
> +nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct
> vtn_value *val,
>                                         SpvOp opcode, bool *swap,
>                                         unsigned src_bit_size, unsigned
> dst_bit_size);
>
> --
> 2.14.3
>
> _______________________________________________
> mesa-dev mailing list
> mesa-dev at lists.freedesktop.org
> https://lists.freedesktop.org/mailman/listinfo/mesa-dev
>
-------------- next part --------------
An HTML attachment was scrubbed...
URL: <https://lists.freedesktop.org/archives/mesa-dev/attachments/20180428/e74fc0df/attachment-0001.html>


More information about the mesa-dev mailing list