[Mesa-dev] [PATCH 3/4] nir: add all combinations of conversions with rounding and saturation

Karol Herbst kherbst at redhat.com
Sat Apr 28 11:14:14 UTC 2018


OpenCL has explicit conversion functions that let the caller specify a rounding
mode and add a _sat (saturation) modifier:

https://www.khronos.org/registry/OpenCL/sdk/2.1/docs/man/xhtml/convert_T.html

_sat is valid for all conversions to an integer type, and rounding modes are
valid for all conversions that involve a floating-point type.

Although the FPRoundingMode decoration is not restricted to any particular
capability, in GLSL it can only be used on fp16 conversions. In OpenCL it can
additionally be used on any conversion to or from a floating-point type.

The SaturatedConversion decoration, OpSatConvertUToS and OpSatConvertSToU all
require the Kernel capability, so current drivers are not affected.

Signed-off-by: Karol Herbst <kherbst at redhat.com>
---
 src/compiler/glsl/glsl_to_nir.cpp |   2 +-
 src/compiler/nir/nir.h            |   2 +-
 src/compiler/nir/nir_opcodes.py   |  28 +++++-----
 src/compiler/nir/nir_opcodes_c.py |  26 +++++----
 src/compiler/spirv/spirv_to_nir.c |   4 +-
 src/compiler/spirv/vtn_alu.c      | 108 ++++++++++++++++++++++++--------------
 src/compiler/spirv/vtn_glsl450.c  |   2 +-
 src/compiler/spirv/vtn_private.h  |   2 +-
 8 files changed, 107 insertions(+), 67 deletions(-)

diff --git a/src/compiler/glsl/glsl_to_nir.cpp b/src/compiler/glsl/glsl_to_nir.cpp
index 8e5e9c34912..fcb6ef27e47 100644
--- a/src/compiler/glsl/glsl_to_nir.cpp
+++ b/src/compiler/glsl/glsl_to_nir.cpp
@@ -1589,7 +1589,7 @@ nir_visitor::visit(ir_expression *ir)
       nir_alu_type src_type = nir_get_nir_type_for_glsl_base_type(types[0]);
       nir_alu_type dst_type = nir_get_nir_type_for_glsl_base_type(out_type);
       result = nir_build_alu(&b, nir_type_conversion_op(src_type, dst_type,
-                                 nir_rounding_mode_undef),
+                                 nir_rounding_mode_undef, false),
                                  srcs[0], NULL, NULL, NULL);
       /* b2i and b2f don't have fixed bit-size versions so the builder will
        * just assume 32 and we have to fix it up here.
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index f3326e6df94..f32e5bd8bb2 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -784,7 +784,7 @@ nir_get_nir_type_for_glsl_type(const struct glsl_type *type)
 }
 
 nir_op nir_type_conversion_op(nir_alu_type src, nir_alu_type dst,
-                              nir_rounding_mode rnd);
+                              nir_rounding_mode rnd, bool saturation);
 
 typedef enum {
    NIR_OP_IS_COMMUTATIVE = (1 << 0),
diff --git a/src/compiler/nir/nir_opcodes.py b/src/compiler/nir/nir_opcodes.py
index f4cd175bc6a..9c51f77bf1b 100644
--- a/src/compiler/nir/nir_opcodes.py
+++ b/src/compiler/nir/nir_opcodes.py
@@ -168,26 +168,61 @@ unop("flog2", tfloat, "log2f(src0)")

+def conv_const_expr(src_t, dst_t, bit_size, rnd_mode, sat_mode):
+   """Return the constant-folding C expression for a conversion opcode.
+
+   A plain C cast neither saturates nor honors an explicit rounding mode,
+   and casting an out-of-range float to an integer type is undefined
+   behavior in C, so both are spelled out for conversions to integers.
+   Conversions to float still fold through a plain cast (the host's current
+   rounding mode, normally round-to-nearest-even), so their _rtz/_ru/_rd
+   variants may fold inexactly.
+   """
+   if src_t == tfloat and dst_t != tfloat:
+      val = {'_rtne': "_mesa_roundeven(src0)", '_rtz': "trunc(src0)",
+             '_ru': "ceil(src0)", '_rd': "floor(src0)",
+             '': "src0"}[rnd_mode]
+   else:
+      val = "src0"
+   if not sat_mode:
+      return val
+
+   c_type = "{0}{1}_t".format(dst_t, bit_size)
+   hi = "{0}{1}_MAX".format(dst_t.upper(), bit_size)
+   lo = "INT{0}_MIN".format(bit_size) if dst_t == tint else "0"
+   if src_t == tfloat:
+      # OpenCL requires NaN to convert to 0 under saturation.
+      return ("isnan(src0) ? 0 : ({0}) >= (double){1} ? {1} : "
+              "({0}) <= (double){2} ? {2} : ({3})({0})").format(
+                 val, hi, lo, c_type)
+   expr = "src0 > {0} ? {0} : src0".format(hi)
+   if src_t == tint:
+      expr = "src0 < {0} ? {0} : {1}".format(lo, expr)
+   return expr
+
 # Generate all of the numeric conversion opcodes
 for src_t in [tint, tuint, tfloat]:
-   if src_t in (tint, tuint):
-      dst_types = [tfloat, src_t]
-   elif src_t == tfloat:
-      dst_types = [tint, tuint, tfloat]
-
-   for dst_t in dst_types:
+   for dst_t in [tint, tuint, tfloat]:
       if dst_t == tfloat:
          bit_sizes = [16, 32, 64]
+         sat_modes = ['']
       else:
          bit_sizes = [8, 16, 32, 64]
+         if src_t != tfloat and dst_t != src_t:
+            sat_modes = ['_sat']
+         else:
+            sat_modes = ['_sat', '']
       for bit_size in bit_sizes:
-          if dst_t == tfloat and src_t == tfloat:
-              rnd_modes = ['_rtne', '_rtz', '']
-              for rnd_mode in rnd_modes:
-                  unop_convert("{0}2{1}{2}{3}".format(src_t[0], dst_t[0],
-                                                       bit_size, rnd_mode),
-                               dst_t + str(bit_size), src_t, "src0")
-          else:
-              unop_convert("{0}2{1}{2}".format(src_t[0], dst_t[0], bit_size),
-                           dst_t + str(bit_size), src_t, "src0")
+          for sat_mode in sat_modes:
+              if src_t == tfloat or dst_t == tfloat:
+                  rnd_modes = ['_rtne', '_rtz', '_ru', '_rd', '']
+              else:
+                  rnd_modes = ['']
+              for rnd_mode in rnd_modes:
+                  unop_convert("{0}2{1}{2}{3}{4}".format(src_t[0], dst_t[0],
+                                                         bit_size, rnd_mode,
+                                                         sat_mode),
+                               dst_t + str(bit_size), src_t,
+                               conv_const_expr(src_t, dst_t, bit_size,
+                                               rnd_mode, sat_mode))

 # We'll hand-code the to/from bool conversion opcodes.  Because bool doesn't
 # have multiple bit-sizes, we can always infer the size from the other type.
diff --git a/src/compiler/nir/nir_opcodes_c.py b/src/compiler/nir/nir_opcodes_c.py
index 19079f86e7b..9b8642f0cc1 100644
--- a/src/compiler/nir/nir_opcodes_c.py
+++ b/src/compiler/nir/nir_opcodes_c.py
@@ -30,7 +30,8 @@ template = Template("""
 #include "nir.h"
 
 nir_op
-nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd)
+nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd,
+                       bool saturate)
 {
    nir_alu_type src_base = (nir_alu_type) nir_alu_type_get_base_type(src);
    nir_alu_type dst_base = (nir_alu_type) nir_alu_type_get_base_type(dst);
@@ -41,7 +42,8 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd
       return nir_op_fmov;
    } else if ((src_base == nir_type_int || src_base == nir_type_uint) &&
               (dst_base == nir_type_int || dst_base == nir_type_uint) &&
-              src_bit_size == dst_bit_size) {
+              src_bit_size == dst_bit_size &&
+              (src_base == dst_base || !saturate)) {
       /* Integer <-> integer conversions with the same bit-size on both
        * ends are just no-op moves.
        */
@@ -54,12 +56,9 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd
          switch (dst_base) {
 %           for dst_t in ['int', 'uint', 'float']:
             case nir_type_${dst_t}:
+<%             orig_dst_t = dst_t %>
 %              if src_t in ['int', 'uint'] and dst_t in ['int', 'uint']:
-%                 if dst_t == 'int':
-<%                   continue %>
-%                 else:
-<%                   dst_t = src_t %>
-%                 endif
+<%                dst_t = src_t %>
 %              endif
                switch (dst_bit_size) {
 %                 if dst_t == 'float':
@@ -69,18 +68,25 @@ nir_type_conversion_op(nir_alu_type src, nir_alu_type dst, nir_rounding_mode rnd
 %                 endif
 %                 for dst_bits in bit_sizes:
                   case ${dst_bits}:
-%                    if src_t == 'float' and dst_t == 'float':
+%                    if src_t == 'float' or dst_t == 'float':
                      switch(rnd) {
-%                       for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'), ('undef', '')]:
+%                       for rnd_t in [('rtne', '_rtne'), ('rtz', '_rtz'), ('ru', '_ru'), ('rd', '_rd'), ('undef', '')]:
                         case nir_rounding_mode_${rnd_t[0]}:
+%                          if dst_t != 'float':
+                           if (saturate)
+                              return ${'nir_op_{0}2{1}{2}{3}_sat'.format(src_t[0], dst_t[0],
+                                                                         dst_bits, rnd_t[1])};
+%                          endif
                            return ${'nir_op_{0}2{1}{2}{3}'.format(src_t[0], dst_t[0],
                                                                    dst_bits, rnd_t[1])};
 %                       endfor
                         default:
-                           unreachable("Invalid 16-bit nir rounding mode");
+                           unreachable("Invalid float nir rounding mode");
                      }
 %                    else:
                      assert(rnd == nir_rounding_mode_undef);
+                     if (saturate)
+                        return ${'nir_op_{0}2{1}{2}_sat'.format(src_t[0], orig_dst_t[0], dst_bits)};
                      return ${'nir_op_{0}2{1}{2}'.format(src_t[0], dst_t[0], dst_bits)};
 %                    endif
 %                 endfor
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 2a835f047e4..6f1a1871b38 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -1726,7 +1726,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
             bit_size = glsl_get_bit_size(val->type->type);
          };
 
-         nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+         nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
                                                      nir_alu_type_get_type_size(src_alu_type),
                                                      nir_alu_type_get_type_size(dst_alu_type));
          nir_const_value src[4];
@@ -3839,6 +3839,8 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpUConvert:
    case SpvOpSConvert:
    case SpvOpFConvert:
+   case SpvOpSatConvertUToS:
+   case SpvOpSatConvertSToU:
    case SpvOpQuantizeToF16:
    case SpvOpConvertPtrToU:
    case SpvOpConvertUToPtr:
diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c
index 3134849ba90..b96f7d688fb 100644
--- a/src/compiler/spirv/vtn_alu.c
+++ b/src/compiler/spirv/vtn_alu.c
@@ -273,8 +273,51 @@ vtn_handle_bitcast(struct vtn_builder *b, struct vtn_ssa_value *dest,
    dest->def = nir_vec(&b->nb, dest_chan, dest_components);
 }

+/* vtn_foreach_decoration() callback: translate an FPRoundingMode decoration
+ * on the value into *_out_rounding_mode; other decorations are ignored.
+ */
+static void
+handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
+                     const struct vtn_decoration *dec, void *_out_rounding_mode)
+{
+   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
+   assert(dec->scope == VTN_DEC_DECORATION);
+   if (dec->decoration != SpvDecorationFPRoundingMode)
+      return;
+   switch (dec->literals[0]) {
+   case SpvFPRoundingModeRTE:
+      *out_rounding_mode = nir_rounding_mode_rtne;
+      break;
+   case SpvFPRoundingModeRTZ:
+      *out_rounding_mode = nir_rounding_mode_rtz;
+      break;
+   case SpvFPRoundingModeRTP:
+      *out_rounding_mode = nir_rounding_mode_ru;
+      break;
+   case SpvFPRoundingModeRTN:
+      *out_rounding_mode = nir_rounding_mode_rd;
+      break;
+   default:
+      /* The literal comes straight from the SPIR-V binary, so reject unknown
+       * values cleanly; unreachable() is undefined behavior in release builds. */
+      vtn_fail("Unsupported FPRoundingMode %u", dec->literals[0]);
+      break;
+   }
+}
+
+static void
+handle_saturation(struct vtn_builder *b, struct vtn_value *val, int member,
+                  const struct vtn_decoration *dec, void *_out_saturation)
+{
+   bool *out_saturation = _out_saturation;
+   assert(dec->scope == VTN_DEC_DECORATION);
+   if (dec->decoration != SpvDecorationSaturatedConversion)
+      return;
+   *out_saturation = true;
+}
+
 nir_op
-vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
+vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct vtn_value *val,
                                 SpvOp opcode, bool *swap,
                                 unsigned src_bit_size, unsigned dst_bit_size)
 {
@@ -356,42 +394,67 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
    case SpvOpConvertSToF:
    case SpvOpConvertUToF:
    case SpvOpSConvert:
-   case SpvOpFConvert: {
+   case SpvOpFConvert:
+   case SpvOpSatConvertUToS:
+   case SpvOpSatConvertSToU: {
       nir_alu_type src_type;
       nir_alu_type dst_type;
 
+      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
+      bool saturation = false;
+
       switch (opcode) {
       case SpvOpConvertFToS:
          src_type = nir_type_float;
          dst_type = nir_type_int;
+         vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
+         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
          break;
       case SpvOpConvertFToU:
          src_type = nir_type_float;
          dst_type = nir_type_uint;
+         vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
+         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
          break;
       case SpvOpFConvert:
          src_type = dst_type = nir_type_float;
+         vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
          break;
       case SpvOpConvertSToF:
          src_type = nir_type_int;
          dst_type = nir_type_float;
+         vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
          break;
       case SpvOpSConvert:
          src_type = dst_type = nir_type_int;
+         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
          break;
       case SpvOpConvertUToF:
          src_type = nir_type_uint;
          dst_type = nir_type_float;
+         vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
          break;
       case SpvOpUConvert:
          src_type = dst_type = nir_type_uint;
+         vtn_foreach_decoration(b, val, handle_saturation, &saturation);
+         break;
+      case SpvOpSatConvertUToS:
+         src_type = nir_type_uint;
+         dst_type = nir_type_int;
+         saturation = true;
+         break;
+      case SpvOpSatConvertSToU:
+         src_type = nir_type_int;
+         dst_type = nir_type_uint;
+         saturation = true;
          break;
       default:
          unreachable("Invalid opcode");
       }
       src_type |= src_bit_size;
       dst_type |= dst_bit_size;
-      return nir_type_conversion_op(src_type, dst_type, nir_rounding_mode_undef);
+
+      return nir_type_conversion_op(src_type, dst_type, rounding_mode, saturation);
    }
    /* Derivatives: */
    case SpvOpDPdx:         return nir_op_fddx;
@@ -417,27 +480,6 @@ handle_no_contraction(struct vtn_builder *b, struct vtn_value *val, int member,
    b->nb.exact = true;
 }
 
-static void
-handle_rounding_mode(struct vtn_builder *b, struct vtn_value *val, int member,
-                     const struct vtn_decoration *dec, void *_out_rounding_mode)
-{
-   nir_rounding_mode *out_rounding_mode = _out_rounding_mode;
-   assert(dec->scope == VTN_DEC_DECORATION);
-   if (dec->decoration != SpvDecorationFPRoundingMode)
-      return;
-   switch (dec->literals[0]) {
-   case SpvFPRoundingModeRTE:
-      *out_rounding_mode = nir_rounding_mode_rtne;
-      break;
-   case SpvFPRoundingModeRTZ:
-      *out_rounding_mode = nir_rounding_mode_rtz;
-      break;
-   default:
-      unreachable("Not supported rounding mode");
-      break;
-   }
-}
-
 void
 vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
                const uint32_t *w, unsigned count)
@@ -579,7 +621,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
       unsigned dst_bit_size = glsl_get_bit_size(type);
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
       if (swap) {
@@ -605,7 +647,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
       unsigned dst_bit_size = glsl_get_bit_size(type);
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
       assert(!swap);
@@ -623,23 +665,11 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
       vtn_handle_bitcast(b, val->ssa, src[0]);
       break;
 
-   case SpvOpFConvert: {
-      nir_alu_type src_alu_type = nir_get_nir_type_for_glsl_type(vtn_src[0]->type);
-      nir_alu_type dst_alu_type = nir_get_nir_type_for_glsl_type(type);
-      nir_rounding_mode rounding_mode = nir_rounding_mode_undef;
-
-      vtn_foreach_decoration(b, val, handle_rounding_mode, &rounding_mode);
-      nir_op op = nir_type_conversion_op(src_alu_type, dst_alu_type, rounding_mode);
-
-      val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], NULL, NULL);
-      break;
-   }
-
    default: {
       bool swap;
       unsigned src_bit_size = glsl_get_bit_size(vtn_src[0]->type);
       unsigned dst_bit_size = glsl_get_bit_size(type);
-      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap,
+      nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, val, opcode, &swap,
                                                   src_bit_size, dst_bit_size);
 
       if (swap) {
diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c
index 6fa759b1bba..284371446b5 100644
--- a/src/compiler/spirv/vtn_glsl450.c
+++ b/src/compiler/spirv/vtn_glsl450.c
@@ -659,7 +659,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
          nir_op conversion_op =
             nir_type_conversion_op(nir_type_float | eta->bit_size,
                                    nir_type_float | I->bit_size,
-                                   nir_rounding_mode_undef);
+                                   nir_rounding_mode_undef, false);
          eta = nir_build_alu(nb, conversion_op, eta, NULL, NULL, NULL);
       }
       /* k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I)) */
diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h
index b501bbf9b4a..0895c865fbb 100644
--- a/src/compiler/spirv/vtn_private.h
+++ b/src/compiler/spirv/vtn_private.h
@@ -708,7 +708,7 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct vtn_builder *,
 void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value *value,
                                 vtn_execution_mode_foreach_cb cb, void *data);
 
-nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
+nir_op vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, struct vtn_value *val,
                                        SpvOp opcode, bool *swap,
                                        unsigned src_bit_size, unsigned dst_bit_size);
 
-- 
2.14.3



More information about the mesa-dev mailing list